AUTOMATIC1111 2 жил өмнө
parent
commit
016554e437

+ 1 - 0
modules/cmd_args.py

@@ -35,6 +35,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
 parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
 parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
 parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
+parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
 parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
 parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
 parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")

+ 2 - 3
modules/interrogate.py

@@ -186,9 +186,8 @@ class InterrogateModels:
         res = ""
         shared.state.begin(job="interrogate")
         try:
-            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-                lowvram.send_everything_to_cpu()
-                devices.torch_gc()
+            lowvram.send_everything_to_cpu()
+            devices.torch_gc()
 
             self.load()
 

+ 16 - 2
modules/lowvram.py

@@ -1,5 +1,5 @@
 import torch
-from modules import devices
+from modules import devices, shared
 
 module_in_gpu = None
 cpu = torch.device("cpu")
@@ -14,6 +14,20 @@ def send_everything_to_cpu():
     module_in_gpu = None
 
 
+def is_needed(sd_model):
+    return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
+
+
+def apply(sd_model):
+    enable = is_needed(sd_model)
+    shared.parallel_processing_allowed = not enable
+
+    if enable:
+        setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
+    else:
+        sd_model.lowvram = False
+
+
 def setup_for_low_vram(sd_model, use_medvram):
     if getattr(sd_model, 'lowvram', False):
         return
@@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
 
 
 def is_enabled(sd_model):
-    return getattr(sd_model, 'lowvram', False)
+    return sd_model.lowvram

+ 8 - 8
modules/sd_models.py

@@ -517,7 +517,7 @@ def get_empty_cond(sd_model):
 
 
 def send_model_to_cpu(m):
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+    if m.lowvram:
         lowvram.send_everything_to_cpu()
     else:
         m.to(devices.cpu)
@@ -525,17 +525,17 @@ def send_model_to_cpu(m):
     devices.torch_gc()
 
 
-def model_target_device():
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+def model_target_device(m):
+    if lowvram.is_needed(m):
         return devices.cpu
     else:
         return devices.device
 
 
 def send_model_to_device(m):
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
-    else:
+    lowvram.apply(m)
+
+    if not m.lowvram:
         m.to(shared.device)
 
 
@@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
             '': torch.float16,
         }
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     timer.record("load weights from state dict")
 
@@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None):
         script_callbacks.model_loaded_callback(sd_model)
         timer.record("script callbacks")
 
-        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+        if not sd_model.lowvram:
             sd_model.to(devices.device)
             timer.record("move model to device")
 

+ 1 - 1
modules/sd_unet.py

@@ -47,7 +47,7 @@ def apply_unet(option=None):
     if current_unet_option is None:
         current_unet = None
 
-        if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+        if not shared.sd_model.lowvram:
             shared.sd_model.model.diffusion_model.to(devices.device)
 
         return

+ 2 - 2
modules/sd_vae.py

@@ -263,7 +263,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     if loaded_vae_file == vae_file:
         return
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+    if sd_model.lowvram:
         lowvram.send_everything_to_cpu()
     else:
         sd_model.to(devices.cpu)
@@ -275,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     sd_hijack.model_hijack.hijack(sd_model)
     script_callbacks.model_loaded_callback(sd_model)
 
-    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+    if not sd_model.lowvram:
         sd_model.to(devices.device)
 
     print("VAE weights loaded.")

+ 1 - 1
modules/shared.py

@@ -11,7 +11,7 @@ cmd_opts = shared_cmd_options.cmd_opts
 parser = shared_cmd_options.parser
 
 batch_cond_uncond = True  # old field, unused now in favor of shared.opts.batch_cond_uncond
-parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
+parallel_processing_allowed = True
 styles_filename = cmd_opts.styles_file
 config_filename = cmd_opts.ui_settings_file
 hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}