@@ -1,7 +1,9 @@
 import collections
+import importlib
 import os
 import sys
 import threading
+import enum
 
 import torch
 import re
@@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
 from urllib import request
 import ldm.modules.midas as midas
 
-from ldm.util import instantiate_from_config
-
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 from modules.shared import opts
@@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()
 
 
+class ModelType(enum.Enum):
+    SD1 = 1
+    SD2 = 2
+    SDXL = 3
+    SSD = 4
+    SD3 = 5
+
+
 def replace_key(d, key, new_key, value):
     keys = list(d.keys())
 
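Note: ModelType gives downstream code a single field to branch on instead of probing each is_* boolean in turn. A minimal consumption sketch (hypothetical caller, not part of this diff):

    # Hypothetical helper: dispatch on model.model_type instead of the
    # per-architecture boolean flags.
    ARCH_NAMES = {
        ModelType.SD1: "Stable Diffusion 1.x",
        ModelType.SD2: "Stable Diffusion 2.x",
        ModelType.SDXL: "SDXL",
        ModelType.SSD: "SSD (Segmind-distilled SDXL)",
        ModelType.SD3: "Stable Diffusion 3",
    }

    def describe(model):
        return ARCH_NAMES[model.model_type]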
@@ -368,6 +376,36 @@ def check_fp8(model):
     return enable_fp8
 
 
+def set_model_type(model, state_dict):
+    model.is_sd1 = False
+    model.is_sd2 = False
+    model.is_sdxl = False
+    model.is_ssd = False
+    model.is_sd3 = False
+
+    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
+        model.is_sd3 = True
+        model.model_type = ModelType.SD3
+    elif hasattr(model, 'conditioner'):
+        model.is_sdxl = True
+
+        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
+            model.is_ssd = True
+            model.model_type = ModelType.SSD
+        else:
+            model.model_type = ModelType.SDXL
+    elif hasattr(model.cond_stage_model, 'model'):
+        model.is_sd2 = True
+        model.model_type = ModelType.SD2
+    else:
+        model.is_sd1 = True
+        model.model_type = ModelType.SD1
+
+
+def set_model_fields(model):
+    if not hasattr(model, 'latent_channels'):
+        model.latent_channels = 4
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
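Note: set_model_type fingerprints the checkpoint itself rather than trusting the config. "model.diffusion_model.x_embedder.proj.weight" is the MMDiT patch-embedding weight that only SD3 checkpoints carry; an SSD checkpoint is an SDXL model whose middle-block attention was pruned, hence the missing to_q key; and SD2 is told apart from SD1 by the open_clip-style cond_stage_model.model attribute. A rough illustration of the branch order with stub objects (assumed shapes, for exposition only):

    # Stub with no 'conditioner' attribute and a cond_stage_model lacking
    # '.model', so only the SD3 marker key can promote it past SD1.
    class _Stub:
        cond_stage_model = object()

    m = _Stub()
    set_model_type(m, {"model.diffusion_model.x_embedder.proj.weight": None})
    assert m.model_type == ModelType.SD3 and m.is_sd3

    m = _Stub()
    set_model_type(m, {})  # no marker keys anywhere -> plain SD1
    assert m.model_type == ModelType.SD1 and m.is_sd1

set_model_fields only backfills latent_channels to 4, the SD1/SD2/SDXL latent depth; the hasattr guard leaves room for a model class (such as an SD3 wrapper with its 16-channel latents) to declare its own value before this runs.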
@@ -382,10 +420,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    model.is_sdxl = hasattr(model, 'conditioner')
-    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
-    model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    set_model_type(model, state_dict)
+    set_model_fields(model)
+
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
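Note: the four inline assignments collapse into the two helpers, and the legacy flags come back with the same values, so checks like the `if model.is_sdxl:` guard just below (and any extension reading model.is_sd2 and friends) behave exactly as before; model.model_type is purely additive.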
@@ -552,8 +589,7 @@ def patch_given_betas():
     original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
 
 
-def repair_config(sd_config):
-
+def repair_config(sd_config, state_dict=None):
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False
 
@@ -563,8 +599,9 @@ def repair_config(sd_config):
     elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
         sd_config.model.params.unet_config.params.use_fp16 = True
 
-    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
-        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
 
     # For UnCLIP-L, override the hardcoded karlo directory
     if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
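Note: the new hasattr guard is needed because not every config defines model.params.first_stage_config (an SD3-style config does not lay out its VAE this way), so the old unconditional attribute chain would raise while loading such a model. In miniature, with a made-up config (this relies on OmegaConf raising AttributeError for missing keys, which is what makes hasattr usable here):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"model": {"params": {}}})    # no first_stage_config
    if hasattr(cfg.model.params, "first_stage_config"):  # False -> ddconfig access skipped
        print(cfg.model.params.first_stage_config.params.ddconfig.attn_type)

The new state_dict parameter is accepted so callers can pass the checkpoint through; nothing in the hunks shown here reads it yet.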
@@ -580,6 +617,7 @@ def repair_config(sd_config):
         sd_config.model.params.unet_config.params.use_checkpoint = False
 
 
+
 def rescale_zero_terminal_snr_abar(alphas_cumprod):
     alphas_bar_sqrt = alphas_cumprod.sqrt()
 
@@ -715,6 +753,25 @@ def send_model_to_trash(m):
     devices.torch_gc()
 
 
+def instantiate_from_config(config, state_dict=None):
+    constructor = get_obj_from_str(config["target"])
+
+    params = {**config.get("params", {})}
+
+    if state_dict and "state_dict" in params and params["state_dict"] is None:
+        params["state_dict"] = state_dict
+
+    return constructor(**params)
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
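Note: this local instantiate_from_config replaces the ldm.util import dropped at the top of the file, with one addition: when the target's params declare state_dict (as null), the checkpoint's state dict is forwarded to the constructor, letting a model class build itself directly from the weights; the `is None` test leaves a config that sets state_dict explicitly alone. get_obj_from_str is the usual dotted-path resolver. A quick self-contained example with a stand-in target (illustrative only; any importable dotted path works):

    config = {"target": "collections.OrderedDict", "params": {"a": 1}}
    obj = instantiate_from_config(config)
    assert isinstance(obj, collections.OrderedDict) and obj["a"] == 1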
@@ -739,7 +796,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("find config")
 
     sd_config = OmegaConf.load(checkpoint_config)
-    repair_config(sd_config)
+    repair_config(sd_config, state_dict)
 
     timer.record("load config")
 
@@ -749,7 +806,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
             with sd_disable_initialization.InitializeOnMeta():
-                sd_model = instantiate_from_config(sd_config.model)
+                sd_model = instantiate_from_config(sd_config.model, state_dict)
 
     except Exception as e:
         errors.display(e, "creating model quickly", full_traceback=True)
@@ -758,7 +815,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
 
         with sd_disable_initialization.InitializeOnMeta():
-            sd_model = instantiate_from_config(sd_config.model)
+            sd_model = instantiate_from_config(sd_config.model, state_dict)
 
     sd_model.used_config = checkpoint_config
 
@@ -775,6 +832,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+
+    if hasattr(sd_model, "after_load_weights"):
+        sd_model.after_load_weights()
+
     timer.record("load weights from state dict")
 
     send_model_to_device(sd_model)
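Note: the after_load_weights call adds an optional post-load hook, run after the state dict has been applied and before the model is sent to its target device; a class opts in simply by defining the method. A sketch of the shape (hypothetical model class, not from this diff):

    class MyModel(torch.nn.Module):
        def after_load_weights(self):
            # runs once per load, after load_model_weights() and before
            # send_model_to_device()
            self.eval()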