Browse Source

Merge pull request #12227 from AUTOMATIC1111/multiple_loaded_models

option to keep multiple models in memory
AUTOMATIC1111 2 years ago
parent
commit
c613416af3
6 changed files with 135 additions and 35 deletions
  1. 3 0
      modules/lowvram.py
  2. 7 3
      modules/sd_hijack.py
  3. 0 2
      modules/sd_hijack_inpainting.py
  4. 110 25
      modules/sd_models.py
  5. 4 4
      modules/sd_models_xl.py
  6. 11 1
      modules/shared.py

+ 3 - 0
modules/lowvram.py

@@ -15,6 +15,9 @@ def send_everything_to_cpu():
 
 
 
 
 def setup_for_low_vram(sd_model, use_medvram):
 def setup_for_low_vram(sd_model, use_medvram):
+    if getattr(sd_model, 'lowvram', False):
+        return
+
     sd_model.lowvram = True
     sd_model.lowvram = True
 
 
     parents = {}
     parents = {}

+ 7 - 3
modules/sd_hijack.py

@@ -5,7 +5,7 @@ from types import MethodType
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
 
 
 import ldm.modules.attention
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
 import ldm.modules.diffusionmodules.model
@@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
 ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 
 
 # silence new console spam from SD2
 # silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
+
+sd_hijack_inpainting.do_inpainting_hijack()
 
 
 optimizers = []
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
 current_optimizer: sd_hijack_optimizations.SdOptimization = None

+ 0 - 2
modules/sd_hijack_inpainting.py

@@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
 
 
 
 
 def do_inpainting_hijack():
 def do_inpainting_hijack():
-    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
     ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
     ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms

+ 110 - 25
modules/sd_models.py

@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config
 from ldm.util import instantiate_from_config
 
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
-from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 from modules.timer import Timer
 import tomesd
 import tomesd
 
 
@@ -434,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
 class SdModelData:
 class SdModelData:
     def __init__(self):
     def __init__(self):
         self.sd_model = None
         self.sd_model = None
+        self.loaded_sd_models = []
         self.was_loaded_at_least_once = False
         self.was_loaded_at_least_once = False
         self.lock = threading.Lock()
         self.lock = threading.Lock()
 
 
@@ -448,6 +448,7 @@ class SdModelData:
 
 
                 try:
                 try:
                     load_model()
                     load_model()
+
                 except Exception as e:
                 except Exception as e:
                     errors.display(e, "loading stable diffusion model", full_traceback=True)
                     errors.display(e, "loading stable diffusion model", full_traceback=True)
                     print("", file=sys.stderr)
                     print("", file=sys.stderr)
@@ -459,11 +460,24 @@ class SdModelData:
     def set_sd_model(self, v):
     def set_sd_model(self, v):
         self.sd_model = v
         self.sd_model = v
 
 
+        try:
+            self.loaded_sd_models.remove(v)
+        except ValueError:
+            pass
+
+        if v is not None:
+            self.loaded_sd_models.insert(0, v)
+
 
 
 model_data = SdModelData()
 model_data = SdModelData()
 
 
 
 
 def get_empty_cond(sd_model):
 def get_empty_cond(sd_model):
+    from modules import extra_networks, processing
+
+    p = processing.StableDiffusionProcessingTxt2Img()
+    extra_networks.activate(p, {})
+
     if hasattr(sd_model, 'conditioner'):
     if hasattr(sd_model, 'conditioner'):
         d = sd_model.get_learned_conditioning([""])
         d = sd_model.get_learned_conditioning([""])
         return d['crossattn']
         return d['crossattn']
@@ -471,19 +485,43 @@ def get_empty_cond(sd_model):
         return sd_model.cond_stage_model([""])
         return sd_model.cond_stage_model([""])
 
 
 
 
+def send_model_to_cpu(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        m.to(devices.cpu)
+
+    devices.torch_gc()
+
+
+def send_model_to_device(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+    else:
+        m.to(shared.device)
+
+
+def send_model_to_trash(m):
+    m.to(device="meta")
+    devices.torch_gc()
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
-    from modules import lowvram, sd_hijack
+    from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
     checkpoint_info = checkpoint_info or select_checkpoint()
 
 
+    timer = Timer()
+
     if model_data.sd_model:
     if model_data.sd_model:
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        send_model_to_trash(model_data.sd_model)
         model_data.sd_model = None
         model_data.sd_model = None
-        gc.collect()
         devices.torch_gc()
         devices.torch_gc()
 
 
-    do_inpainting_hijack()
-
-    timer = Timer()
+    timer.record("unload existing model")
 
 
     if already_loaded_state_dict is not None:
     if already_loaded_state_dict is not None:
         state_dict = already_loaded_state_dict
         state_dict = already_loaded_state_dict
@@ -523,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    timer.record("load weights from state dict")
 
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
-    else:
-        sd_model.to(shared.device)
-
+    send_model_to_device(sd_model)
     timer.record("move model to device")
     timer.record("move model to device")
 
 
     sd_hijack.model_hijack.hijack(sd_model)
     sd_hijack.model_hijack.hijack(sd_model)
@@ -536,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("hijack")
     timer.record("hijack")
 
 
     sd_model.eval()
     sd_model.eval()
-    model_data.sd_model = sd_model
+    model_data.set_sd_model(sd_model)
     model_data.was_loaded_at_least_once = True
     model_data.was_loaded_at_least_once = True
 
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -557,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     return sd_model
     return sd_model
 
 
 
 
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    """
+    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
+    If not, returns the model that can be used to load weights from checkpoint_info's file.
+    If no such model exists, returns None.
+    Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+    """
+
+    already_loaded = None
+    for i in reversed(range(len(model_data.loaded_sd_models))):
+        loaded_model = model_data.loaded_sd_models[i]
+        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+            already_loaded = loaded_model
+            continue
+
+        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+            model_data.loaded_sd_models.pop()
+            send_model_to_trash(loaded_model)
+            timer.record("send model to trash")
+
+        if shared.opts.sd_checkpoints_keep_in_cpu:
+            send_model_to_cpu(sd_model)
+            timer.record("send model to cpu")
+
+    if already_loaded is not None:
+        send_model_to_device(already_loaded)
+        timer.record("send model to device")
+
+        model_data.set_sd_model(already_loaded)
+        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+        return model_data.sd_model
+    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+        model_data.sd_model = None
+        load_model(checkpoint_info)
+        return model_data.sd_model
+    elif len(model_data.loaded_sd_models) > 0:
+        sd_model = model_data.loaded_sd_models.pop()
+        model_data.sd_model = sd_model
+
+        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+        return sd_model
+    else:
+        return None
+
+
 def reload_model_weights(sd_model=None, info=None):
 def reload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
     checkpoint_info = info or select_checkpoint()
 
 
+    timer = Timer()
+
     if not sd_model:
     if not sd_model:
         sd_model = model_data.sd_model
         sd_model = model_data.sd_model
 
 
@@ -569,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
     else:
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
         current_checkpoint_info = sd_model.sd_checkpoint_info
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-            return
-
-        sd_unet.apply_unet("None")
+            return sd_model
 
 
-        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-            lowvram.send_everything_to_cpu()
-        else:
-            sd_model.to(devices.cpu)
+    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+        return sd_model
 
 
+    if sd_model is not None:
+        sd_unet.apply_unet("None")
+        send_model_to_cpu(sd_model)
         sd_hijack.model_hijack.undo_hijack(sd_model)
         sd_hijack.model_hijack.undo_hijack(sd_model)
 
 
-    timer = Timer()
-
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -590,9 +674,8 @@ def reload_model_weights(sd_model=None, info=None):
 
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
     if sd_model is None or checkpoint_config != sd_model.used_config:
         if sd_model is not None:
         if sd_model is not None:
-            sd_model.to(device="meta")
+            send_model_to_trash(sd_model)
 
 
-        devices.torch_gc()
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         return model_data.sd_model
         return model_data.sd_model
 
 
@@ -615,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None):
 
 
     print(f"Weights loaded in {timer.summary()}.")
     print(f"Weights loaded in {timer.summary()}.")
 
 
+    model_data.set_sd_model(sd_model)
+
     return sd_model
     return sd_model
 
 
 
 

+ 4 - 4
modules/sd_models_xl.py

@@ -98,10 +98,10 @@ def extend_sdxl(model):
     model.conditioner.wrapped = torch.nn.Module()
     model.conditioner.wrapped = torch.nn.Module()
 
 
 
 
-sgm.modules.attention.print = lambda *args: None
-sgm.modules.diffusionmodules.model.print = lambda *args: None
-sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
-sgm.modules.encoders.modules.print = lambda *args: None
+sgm.modules.attention.print = shared.ldm_print
+sgm.modules.diffusionmodules.model.print = shared.ldm_print
+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
+sgm.modules.encoders.modules.print = shared.ldm_print
 
 
 # this gets the code to load the vanilla attention that we override
 # this gets the code to load the vanilla attention that we override
 sgm.modules.attention.SDP_IS_AVAILABLE = True
 sgm.modules.attention.SDP_IS_AVAILABLE = True

+ 11 - 1
modules/shared.py

@@ -400,6 +400,7 @@ options_templates.update(options_section(('system', "System"), {
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+    "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
 }))
 }))
 
 
 options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('training', "Training"), {
@@ -419,7 +420,9 @@ options_templates.update(options_section(('training', "Training"), {
 
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
-    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+    "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
@@ -906,3 +909,10 @@ def walk_files(path, allowed_extensions=None):
                 continue
                 continue
 
 
             yield os.path.join(root, filename)
             yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+    if opts.hide_ldm_prints:
+        return
+
+    print(*args, **kwargs)