lora.py
import os
import re
import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}
def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key
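
# Illustrative example (key name assumed for illustration, not taken from a real
# checkpoint): a diffusers-style LoRA key such as
#     "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q"
# is converted to the compvis-style name
#     "diffusion_model_input_blocks_1_1_transformer_blocks_0_attn1_to_q"
# (input block index 1 + 0 * 3 + 0 = 1; the following "1" marks the attention
# sub-block). Keys that match none of the patterns are returned unchanged.
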
class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        if self.is_safetensors:
            try:
                self.metadata = sd_models.read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None)  # those are cover images and they are too big to display in UI as text
        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            available_lora_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
            return self.name
        else:
            return self.alias
class LoraModule:
    def __init__(self, name, lora_on_disk: LoraOnDisk):
        self.name = name
        self.lora_on_disk = lora_on_disk
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add lora to prompt - can be either name or an alias"""


class LoraUpDownModule:
    def __init__(self):
        self.up = None
        self.down = None
        self.alpha = None
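
# For reference: each LoraUpDownModule stores the two low-rank factors of a weight
# delta. With up of shape (out_features, rank) and down of shape (rank, in_features),
# the effective weight applied later is, roughly,
#     W_effective = W_original + multiplier * (alpha / rank) * (up @ down)
# where a missing alpha behaves like alpha == rank (see lora_calc_updown below, which
# uses module.up.weight.shape[1] as the rank).
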
def assign_lora_names_to_compvis_modules(sd_model):
    lora_layer_mapping = {}

    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    for name, module in shared.sd_model.model.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    sd_model.lora_layer_mapping = lora_layer_mapping
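
# Illustrative mapping (module path assumed for illustration): a torch module named
# "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q" inside
# shared.sd_model.model gets lora_layer_name
# "diffusion_model_input_blocks_1_1_transformer_blocks_0_attn1_to_q", the same form
# produced by convert_diffusers_name_to_compvis above.
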
def load_lora(name, lora_on_disk):
    lora = LoraModule(name, lora_on_disk)
    lora.mtime = os.path.getmtime(lora_on_disk.filename)

    sd = sd_models.read_state_dict(lora_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'lora_layer_mapping'):
        assign_lora_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

    for key_diffusers, weight in sd.items():
        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

        if sd_module is None:
            keys_failed_to_match[key_diffusers] = key
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if type(sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.MultiheadAttention:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
        else:
            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
            continue
        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)

        if lora_key == "lora_up.weight":
            lora_module.up = module
        elif lora_key == "lora_down.weight":
            lora_module.down = module
        else:
            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")

    if keys_failed_to_match:
        print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")

    return lora
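
# Illustrative result (checkpoint key assumed for illustration): a state dict entry
# named "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight"
# is split at the first "." into the layer part and "lora_up.weight"; after the loop,
# lora.modules["diffusion_model_input_blocks_1_1_transformer_blocks_0_attn1_to_q"]
# holds a LoraUpDownModule whose .up and .down are small CPU Linear/Conv2d layers and
# whose .alpha is the scalar read from the matching "...alpha" key, if present.
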
def load_loras(names, multipliers=None):
    already_loaded = {}

    for lora in loaded_loras:
        if lora.name in names:
            already_loaded[lora.name] = lora

    loaded_loras.clear()

    loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
    if any(x is None for x in loras_on_disk):
        list_available_loras()

        loras_on_disk = [available_lora_aliases.get(name, None) for name in names]

    failed_to_load_loras = []

    for i, name in enumerate(names):
        lora = already_loaded.get(name, None)
        lora_on_disk = loras_on_disk[i]

        if lora_on_disk is not None:
            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
                try:
                    lora = load_lora(name, lora_on_disk)
                except Exception as e:
                    errors.display(e, f"loading Lora {lora_on_disk.filename}")
                    continue

            lora.mentioned_name = name

            lora_on_disk.read_hash()

        if lora is None:
            failed_to_load_loras.append(name)
            print(f"Couldn't find Lora with name {name}")
            continue

        lora.multiplier = multipliers[i] if multipliers else 1.0
        loaded_loras.append(lora)

    if failed_to_load_loras:
        sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
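
# Typical call (name and multiplier assumed for illustration): prompt syntax such as
# <lora:some_lora:0.8> is resolved elsewhere in this extension (the extra-networks
# handler) into roughly
#     load_loras(["some_lora"], [0.8])
# which refills loaded_loras; loras already loaded and unchanged on disk are reused
# rather than re-read.
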
def lora_calc_updown(lora, module, target):
    with torch.no_grad():
        up = module.up.weight.to(target.device, dtype=target.dtype)
        down = module.down.weight.to(target.device, dtype=target.dtype)

        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

        return updown
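
# Shape check (dimensions assumed for illustration): for a Linear target with weight
# (out_features, in_features) and rank r, up.weight is (out_features, r) and
# down.weight is (r, in_features), so up @ down is (out_features, in_features) and can
# be added directly to the target weight. The 1x1-conv branch is the same matmul with
# the trailing singleton dims squeezed out and restored; the 3x3 branch composes the
# two convolution kernels via torch.nn.functional.conv2d instead.
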
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "lora_weights_backup", None)

    if weights_backup is None:
        return

    if isinstance(self, torch.nn.MultiheadAttention):
        self.in_proj_weight.copy_(weights_backup[0])
        self.out_proj.weight.copy_(weights_backup[1])
    else:
        self.weight.copy_(weights_backup)
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of Loras to the weights of torch layer self.
    If weights already have this particular set of loras applied, does nothing.
    If not, restores original weights from backup and alters weights according to loras.
    """
    lora_layer_name = getattr(self, 'lora_layer_name', None)
    if lora_layer_name is None:
        return

    current_names = getattr(self, "lora_current_names", ())
    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)

    weights_backup = getattr(self, "lora_weights_backup", None)
    if weights_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.lora_weights_backup = weights_backup

    if current_names != wanted_names:
        lora_restore_weights_from_backup(self)

        for lora in loaded_loras:
            module = lora.modules.get(lora_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                self.weight += lora_calc_updown(lora, module, self.weight)
                continue

            module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
            module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
            module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
            module_out = lora.modules.get(lora_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
                updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
                updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
                updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

                self.in_proj_weight += updown_qkv
                self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
                continue

            if module is None:
                continue

            print(f'failed to calculate lora weights for layer {lora_layer_name}')

        self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.
    """

    if len(loaded_loras) == 0:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    lora_restore_weights_from_backup(module)
    lora_reset_cached_weight(module)

    res = original_forward(module, input)

    lora_layer_name = getattr(module, 'lora_layer_name', None)
    for lora in loaded_loras:
        module = lora.modules.get(lora_layer_name, None)
        if module is None:
            continue

        module.up.to(device=devices.device)
        module.down.to(device=devices.device)

        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

    return res
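
# Note: for a Linear layer, module.up(module.down(input)) equals
# input @ (up.weight @ down.weight).T, so scaling that product and adding it to the
# layer's weight (the lora_apply_weights path above) produces the same output as this
# per-call formulation, but pays the cost once instead of on every forward pass.
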
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.lora_current_names = ()
    self.lora_weights_backup = None


def lora_Linear_forward(self, input):
    if shared.opts.lora_functional:
        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)

    lora_apply_weights(self)

    return torch.nn.Linear_forward_before_lora(self, input)


def lora_Linear_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)


def lora_Conv2d_forward(self, input):
    if shared.opts.lora_functional:
        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)

    lora_apply_weights(self)

    return torch.nn.Conv2d_forward_before_lora(self, input)


def lora_Conv2d_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_forward(self, *args, **kwargs):
    lora_apply_weights(self)

    return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
def list_available_loras():
    available_loras.clear()
    available_lora_aliases.clear()
    forbidden_lora_aliases.clear()
    available_lora_hash_lookup.clear()
    forbidden_lora_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in sorted(candidates, key=str.lower):
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = LoraOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
            continue

        available_loras[name] = entry

        if entry.alias in available_lora_aliases:
            forbidden_lora_aliases[entry.alias.lower()] = 1

        available_lora_aliases[name] = entry
        available_lora_aliases[entry.alias] = entry
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_lora_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)
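
# Illustrative paste (parameter values assumed for illustration): infotext parameters
#     "AddNet Module 1": "LoRA", "AddNet Model 1": "some_lora(ab12cd34)",
#     "AddNet Weight A 1": "0.8"
# append "<lora:some_lora:0.8>" to params["Prompt"], so generations restored from such
# infotext use this extension's syntax when the Additional Networks extension is absent.
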
available_loras = {}
available_lora_aliases = {}
available_lora_hash_lookup = {}
forbidden_lora_aliases = {}

loaded_loras = []

list_available_loras()