@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
-from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 import tomesd
@@ -434,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
 
 class SdModelData:
     def __init__(self):
         self.sd_model = None
+        self.loaded_sd_models = []
         self.was_loaded_at_least_once = False
         self.lock = threading.Lock()
@@ -448,6 +448,7 @@ class SdModelData:
 
                 try:
                     load_model()
+
                 except Exception as e:
                     errors.display(e, "loading stable diffusion model", full_traceback=True)
                     print("", file=sys.stderr)
@@ -459,11 +460,25 @@ class SdModelData:
     def set_sd_model(self, v):
         self.sd_model = v
 
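+        # keep loaded_sd_models in most-recently-used order: the current model goes first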
+        try:
+            self.loaded_sd_models.remove(v)
+        except ValueError:
+            pass
+
+        if v is not None:
+            self.loaded_sd_models.insert(0, v)
+
 
 model_data = SdModelData()
 
 
 def get_empty_cond(sd_model):
+    from modules import extra_networks, processing
+
+    p = processing.StableDiffusionProcessingTxt2Img()
+    extra_networks.activate(p, {})
+
     if hasattr(sd_model, 'conditioner'):
         d = sd_model.get_learned_conditioning([""])
         return d['crossattn']
@@ -471,19 +486,44 @@ def get_empty_cond(sd_model):
     return sd_model.cond_stage_model([""])
 
 
+def send_model_to_cpu(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        m.to(devices.cpu)
+
+    devices.torch_gc()
+
+
+def send_model_to_device(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+    else:
+        m.to(shared.device)
+
+
+def send_model_to_trash(m):
+ m.to(device="meta")
|
|
|
+ devices.torch_gc()
|
|
|
+
|
|
|
+
|
|
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|
|
- from modules import lowvram, sd_hijack
|
|
|
+ from modules import sd_hijack
|
|
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
|
|
|
|
|
+ timer = Timer()
|
|
|
+
|
|
|
if model_data.sd_model:
|
|
|
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
|
|
+ send_model_to_trash(model_data.sd_model)
|
|
|
model_data.sd_model = None
|
|
|
- gc.collect()
|
|
|
devices.torch_gc()
|
|
|
|
|
|
- do_inpainting_hijack()
|
|
|
-
|
|
|
- timer = Timer()
|
|
|
+ timer.record("unload existing model")
|
|
|
|
|
|
if already_loaded_state_dict is not None:
|
|
|
state_dict = already_loaded_state_dict
|
|
@@ -523,12 +563,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    timer.record("load weights from state dict")
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
-    else:
-        sd_model.to(shared.device)
-
+    send_model_to_device(sd_model)
     timer.record("move model to device")
 
     sd_hijack.model_hijack.hijack(sd_model)
@@ -536,7 +573,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("hijack")
 
     sd_model.eval()
-    model_data.sd_model = sd_model
+    model_data.set_sd_model(sd_model)
     model_data.was_loaded_at_least_once = True
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -557,10 +594,63 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     return sd_model
 
 
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    """
+    Checks if the desired checkpoint from checkpoint_info is already loaded in model_data.loaded_sd_models.
+    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
+    If not, returns the model that can be used to load weights from checkpoint_info's file.
+    If no such model exists, returns None.
+    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+    """
+
+    already_loaded = None
+    for i in reversed(range(len(model_data.loaded_sd_models))):
+        loaded_model = model_data.loaded_sd_models[i]
+        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+            already_loaded = loaded_model
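+            # this is the checkpoint we want to use; keep it loaded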
+            continue
+
+        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+            model_data.loaded_sd_models.pop()
+            send_model_to_trash(loaded_model)
+            timer.record("send model to trash")
+
+        if shared.opts.sd_checkpoints_keep_in_cpu:
+            send_model_to_cpu(sd_model)
+            timer.record("send model to cpu")
+
+    if already_loaded is not None:
+        send_model_to_device(already_loaded)
+        timer.record("send model to device")
+
+        model_data.set_sd_model(already_loaded)
+        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+        return model_data.sd_model
+    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+        model_data.sd_model = None
+        load_model(checkpoint_info)
+        return model_data.sd_model
+    elif len(model_data.loaded_sd_models) > 0:
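+        # no free slot under sd_checkpoints_limit: reuse the least recently used model as the target for the new weights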
+        sd_model = model_data.loaded_sd_models.pop()
+        model_data.sd_model = sd_model
+
+        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+        return sd_model
+    else:
+        return None
+
+
 def reload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
 
+    timer = Timer()
+
     if not sd_model:
         sd_model = model_data.sd_model
 
@@ -569,19 +659,17 @@ def reload_model_weights(sd_model=None, info=None):
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-            return
-
-    sd_unet.apply_unet("None")
+            return sd_model
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        sd_model.to(devices.cpu)
+    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+        return sd_model
 
-    sd_hijack.model_hijack.undo_hijack(sd_model)
+    if sd_model is not None:
+        sd_unet.apply_unet("None")
+        send_model_to_cpu(sd_model)
+        sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    timer = Timer()
-
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -590,9 +678,8 @@ def reload_model_weights(sd_model=None, info=None):
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
         if sd_model is not None:
-            sd_model.to(device="meta")
+            send_model_to_trash(sd_model)
 
-        devices.torch_gc()
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         return model_data.sd_model
 
@@ -615,6 +702,8 @@ def reload_model_weights(sd_model=None, info=None):
 
     print(f"Weights loaded in {timer.summary()}.")
 
+    model_data.set_sd_model(sd_model)
+
     return sd_model
 
 