瀏覽代碼

Use options instead of cmd_args

Kohaku-Blueleaf 1 年之前
父節點
當前提交
598da5cd49
共有 6 個文件被更改,包括 49 次插入42 次删除
  1. 0 2
      modules/cmd_args.py
  2. 14 11
      modules/devices.py
  3. 1 0
      modules/initialize_util.py
  4. 32 29
      modules/sd_models.py
  5. 1 0
      modules/shared_options.py
  6. 1 0
      scripts/xyz_grid.py

+ 0 - 2
modules/cmd_args.py

@@ -118,5 +118,3 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time
 parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
 parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
 parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
-parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
-parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)

+ 14 - 11
modules/devices.py

@@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool:
     if device_id is None:
         device_id = get_cuda_device_id()
     return (
-        torch.cuda.get_device_capability(device_id) == (7, 5) 
+        torch.cuda.get_device_capability(device_id) == (7, 5)
         and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
     )
 
 
 def get_cuda_device_id():
     return (
-        int(shared.cmd_opts.device_id) 
-        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() 
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
         else 0
     ) or torch.cuda.current_device()
 
@@ -116,16 +116,19 @@ patch_module_list = [
     torch.nn.LayerNorm,
 ]
 
+
+def manual_cast_forward(self, *args, **kwargs):
+    org_dtype = next(self.parameters()).dtype
+    self.to(dtype)
+    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+    result = self.org_forward(*args, **kwargs)
+    self.to(org_dtype)
+    return result
+
+
 @contextlib.contextmanager
 def manual_autocast():
-    def manual_cast_forward(self, *args, **kwargs):
-        org_dtype = next(self.parameters()).dtype
-        self.to(dtype)
-        args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-        kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-        result = self.org_forward(*args, **kwargs)
-        self.to(org_dtype)
-        return result
     for module_type in patch_module_list:
         org_forward = module_type.forward
         module_type.forward = manual_cast_forward

+ 1 - 0
modules/initialize_util.py

@@ -177,6 +177,7 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
     startup_timer.record("opts onchange")
 
 

+ 32 - 29
modules/sd_models.py

@@ -339,10 +339,28 @@ class SkipWritingToConfig:
         SkipWritingToConfig.skip = self.previous
 
 
+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
+    if not check_fp8(model) and devices.fp8:
+        # prevent model to load state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
@@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
-    if devices.get_optimal_device_name() == "mps":
-        enable_fp8 = False
-    elif shared.cmd_opts.opt_unet_fp8_storage:
-        enable_fp8 = True
-    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
-        enable_fp8 = True
-    else:
-        enable_fp8 = False
-
-    if enable_fp8:
+    if check_fp8(model):
         devices.fp8 = True
-        if model.is_sdxl:
-            cond_stage = model.conditioner
-        else:
-            cond_stage = model.cond_stage_model
-
-        for module in cond_stage.modules():
-            if isinstance(module, torch.nn.Linear):
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, torch.nn.Conv2d):
                 module.to(torch.float8_e4m3fn)
-
-        if devices.device == devices.cpu:
-            for module in model.model.diffusion_model.modules():
-                if isinstance(module, torch.nn.Conv2d):
-                    module.to(torch.float8_e4m3fn)
-                elif isinstance(module, torch.nn.Linear):
-                    module.to(torch.float8_e4m3fn)
-        else:
-            model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+            elif isinstance(module, torch.nn.Linear):
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
         devices.fp8 = False
@@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
         return None
 
 
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
     checkpoint_info = info or select_checkpoint()
 
     timer = Timer()
@@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None):
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
-        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+        if check_fp8(sd_model) != devices.fp8:
+            # load from state dict again to prevent extra numerical errors
+            forced_reload = True
+        elif sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return sd_model
 
     sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
         return sd_model
 
     if sd_model is not None:

+ 1 - 0
modules/shared_options.py

@@ -200,6 +200,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {

+ 1 - 0
scripts/xyz_grid.py

@@ -270,6 +270,7 @@ axis_options = [
     AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
     AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
+    AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
 ]