# sd_models.py
import collections
import os.path
import sys
import gc
import threading
import torch
import re
import safetensors.torch
from omegaconf import OmegaConf
from os import mkdir
from urllib import request

import ldm.modules.midas as midas
from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer

import tomesd


model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

checkpoints_list = {}
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()


class CheckpointInfo:
    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)

        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(filename)

        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        def read_metadata():
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)

            return metadata

        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

    def register(self):
        checkpoints_list[self.title] = self
        for id in self.ids:
            checkpoint_aliases[id] = self

    def calculate_shorthash(self):
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        shorthash = self.sha256[0:10]
        if self.shorthash == self.sha256[0:10]:
            return self.shorthash

        self.shorthash = shorthash

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
        self.register()

        return self.shorthash
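
# Illustrative usage sketch (not executed by this module; the path below is
# hypothetical): constructing a CheckpointInfo and registering it makes the
# checkpoint reachable through any of its aliases.
#
#     info = CheckpointInfo("/models/Stable-diffusion/v1-5-pruned-emaonly.safetensors")
#     info.register()
#     assert checkpoint_aliases[info.title] is info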


try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    pass


def setup_model():
    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()


def checkpoint_tiles(use_short=False):
    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]


def list_models():
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
        model_url = None
    else:
        model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])

    if os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()


re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")


def get_closet_checkpoint_match(search_string):
    # "closet" is a long-standing misspelling of "closest"; the name is kept because other modules call it
    if not search_string:
        return None

    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
        return checkpoint_info

    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
    if found:
        return found[0]

    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
    if found:
        return found[0]

    return None


def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""

    try:
        with open(filename, "rb") as file:
            import hashlib
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'
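
# For clarity: the legacy hash above digests only the 64 KiB window starting at
# offset 1 MiB (seek(0x100000), read(0x10000)) and keeps 8 hex characters, so any
# two files that agree on that window collide. For example, with hypothetical
# files a.ckpt and b.ckpt:
#
#     model_hash("a.ckpt") == model_hash("b.ckpt")
#     # True whenever bytes 0x100000..0x10ffff of the two files are identical
#
# The sha256-based shorthash used elsewhere in this module avoids that weakness.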


def select_checkpoint():
    """Raises `FileNotFoundError` if no checkpoints are found."""
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        error_message += "\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info


checkpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    for text, replacement in checkpoint_dict_replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]

    return k
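
# Example of the remapping performed above (legacy CLIP key layout in, the
# text_model layout expected by newer transformers versions out):
#
#     transform_checkpoint_dict_key('cond_stage_model.transformer.encoder.layers.0.mlp.fc1.weight')
#     # -> 'cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight'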


def get_state_dict_from_checkpoint(pl_sd):
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    sd = {}
    for k, v in pl_sd.items():
        new_key = transform_checkpoint_dict_key(k)

        if new_key is not None:
            sd[new_key] = v

    pl_sd.clear()
    pl_sd.update(sd)

    return pl_sd


def read_metadata_from_safetensors(filename):
    import json

    with open(filename, mode="rb") as file:
        metadata_len = file.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = file.read(2)

        assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
        json_data = json_start + file.read(metadata_len-2)
        json_obj = json.loads(json_data)

        res = {}
        for k, v in json_obj.get("__metadata__", {}).items():
            res[k] = v
            if isinstance(v, str) and v[0:1] == '{':
                try:
                    res[k] = json.loads(v)
                except Exception:
                    pass

        return res
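
# The reader above relies on the safetensors container layout: an 8-byte
# little-endian unsigned integer giving the JSON header length, then the header
# itself; user metadata lives under the header's "__metadata__" key. A usage
# sketch (hypothetical path; the metadata key is only present if the trainer
# wrote it):
#
#     meta = read_metadata_from_safetensors("/models/Stable-diffusion/model.safetensors")
#     print(meta.get("ss_sd_model_name"))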


def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            with open(checkpoint_file, 'rb') as f:
                pl_sd = safetensors.torch.load(f.read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd
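
# Usage sketch (hypothetical path): forcing the state dict onto the CPU instead
# of the configured weight_load_location:
#
#     sd = read_state_dict("/models/Stable-diffusion/model.ckpt", map_location="cpu")
#     print(f"{len(sd)} tensors loaded")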


def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info in checkpoints_loaded:
        # use checkpoint cache
        print(f"Loading weights [{sd_model_hash}] from cache")
        return checkpoints_loaded[checkpoint_info]

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
    res = read_state_dict(checkpoint_info.filename)
    timer.record("load weights from disk")

    return res


class SkipWritingToConfig:
    """This context manager prevents load_model_weights from writing the checkpoint name to the config when it loads weights."""

    skip = False
    previous = None

    def __enter__(self):
        self.previous = SkipWritingToConfig.skip
        SkipWritingToConfig.skip = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        SkipWritingToConfig.skip = self.previous
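
# Usage sketch for the context manager above: swap checkpoints without the
# choice being persisted to the user's settings (checkpoint_info here stands
# for a hypothetical CheckpointInfo instance):
#
#     with SkipWritingToConfig():
#         reload_model_weights(info=checkpoint_info)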


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    model.is_sdxl = hasattr(model, 'conditioner')
    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
    model.is_sd1 = not model.is_sdxl and not model.is_sd2

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict

    del state_dict

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if shared.cmd_opts.no_half:
        model.float()
        devices.dtype_unet = torch.float32
        timer.record("apply float()")
    else:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        model.half()
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        devices.dtype_unet = torch.float16
        timer.record("apply half()")

    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")


def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When a model such as 512-depth-ema is loaded, it calls midas.api.load_model
    to load the associated midas depth model. This function wraps load_model so
    that the weights are downloaded to the correct location automatically.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
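
# After enable_midas_autodownload() has run, callers can load a depth model by
# name and the weights are fetched on first use; a sketch (the model type must
# be one of the keys of midas_urls above):
#
#     result = midas.api.load_model("dpt_hybrid")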


def repair_config(sd_config):

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    if hasattr(sd_config.model.params, 'unet_config'):
        if shared.cmd_opts.no_half:
            sd_config.model.params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling:
            sd_config.model.params.unet_config.params.use_fp16 = True

    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)


sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'


class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        if self.was_loaded_at_least_once:
            return self.sd_model

        if self.sd_model is None:
            with self.lock:
                if self.sd_model is not None or self.was_loaded_at_least_once:
                    return self.sd_model

                try:
                    load_model()
                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
                    print("Stable diffusion model failed to load", file=sys.stderr)
                    self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v, already_loaded=False):
        self.sd_model = v
        if already_loaded:
            sd_vae.base_vae = getattr(v, "base_vae", None)
            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
            sd_vae.checkpoint_info = v.sd_checkpoint_info

        try:
            self.loaded_sd_models.remove(v)
        except ValueError:
            pass

        if v is not None:
            self.loaded_sd_models.insert(0, v)


model_data = SdModelData()


def get_empty_cond(sd_model):
    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'conditioner'):
        d = sd_model.get_learned_conditioning([""])
        return d['crossattn']
    else:
        return sd_model.cond_stage_model([""])


def send_model_to_cpu(m):
    if m.lowvram:
        lowvram.send_everything_to_cpu()
    else:
        m.to(devices.cpu)

    devices.torch_gc()


def model_target_device(m):
    if lowvram.is_needed(m):
        return devices.cpu
    else:
        return devices.device


def send_model_to_device(m):
    lowvram.apply(m)

    if not m.lowvram:
        m.to(shared.device)


def send_model_to_trash(m):
    m.to(device="meta")
    devices.torch_gc()


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    if model_data.sd_model:
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

    timer.record("unload existing model")

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
            with sd_disable_initialization.InitializeOnMeta():
                sd_model = instantiate_from_config(sd_config.model)

    except Exception as e:
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

        with sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    if shared.cmd_opts.no_half:
        weight_dtype_conversion = None
    else:
        weight_dtype_conversion = {
            'first_stage_model': None,
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model


def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks if the desired checkpoint from checkpoint_info is already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that model (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
    If not, returns the model that can be used to load weights from checkpoint_info's file.
    If no such model exists, returns None.
    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    already_loaded = None
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue

        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            model_data.loaded_sd_models.pop()
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

        if shared.opts.sd_checkpoints_keep_in_cpu:
            send_model_to_cpu(sd_model)
            timer.record("send model to cpu")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded, already_loaded=True)

        if not SkipWritingToConfig.skip:
            shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
            shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        sd_vae.reload_vae_weights(already_loaded)
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        sd_vae.base_vae = getattr(sd_model, "base_vae", None)
        sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
        sd_vae.checkpoint_info = sd_model.sd_checkpoint_info

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None


def reload_model_weights(sd_model=None, info=None):
    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return sd_model

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

        if not sd_model.lowvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)
    sd_unet.apply_unet()

    return sd_model


def unload_model_weights(sd_model=None, info=None):
    timer = Timer()

    if model_data.sd_model:
        model_data.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()

    print(f"Unloaded weights in {timer.summary()}.")

    return sd_model


def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.
    """

    current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)

    if current_token_merging_ratio == token_merging_ratio:
        return

    if current_token_merging_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        tomesd.apply_patch(
            sd_model,
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False
        )

    sd_model.applied_token_merged_ratio = token_merging_ratio
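

# Usage sketch: patch the loaded model with an illustrative merging ratio, then
# remove the patch again by passing 0:
#
#     apply_token_merging(shared.sd_model, token_merging_ratio=0.5)
#     apply_token_merging(shared.sd_model, token_merging_ratio=0)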