networks.py

from __future__ import annotations

import gradio as gr
import logging
import os
import re

import lora_patches
import network
import network_lora
import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
import network_oft

import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.models.sd3.mmdit

from lora_logger import logger

module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
    network_oft.ModuleTypeOFT(),
]

re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}

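# Translates a diffusers-style LoRA key into the corresponding compvis module name used by the loaded model.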
def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key

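# Walks the text encoder(s) and the UNet of the loaded model, recording every submodule in
# sd_model.network_layer_mapping under a flattened name and tagging each module with network_layer_name.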
def assign_network_names_to_compvis_modules(sd_model):
    network_layer_mapping = {}

    if shared.sd_model.is_sdxl:
        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                network_name = f'{i}_{name.replace(".", "_")}'
                network_layer_mapping[network_name] = module
                module.network_layer_name = network_name
    else:
        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)

        for name, module in cond_stage_model.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    for name, module in shared.sd_model.model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    sd_model.network_layer_mapping = network_layer_mapping

class BundledTIHash(str):
    def __init__(self, hash_str):
        self.hash = hash_str

    def __str__(self):
        return self.hash if shared.opts.lora_bundled_ti_to_infotext else ''

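# Reads a network's state dict from disk, matches its keys against the model's layer mapping,
# creates per-layer network modules, and collects any bundled textual inversion embeddings.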
def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    if hasattr(shared.sd_model, 'diffusers_weight_map'):
        diffusers_weight_map = shared.sd_model.diffusers_weight_map
    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
        diffusers_weight_map = {}
        for k, v in shared.sd_model.diffusers_weight_mapping():
            diffusers_weight_map[k] = v
        shared.sd_model.diffusers_weight_map = diffusers_weight_map
    else:
        diffusers_weight_map = None

    matched_networks = {}
    bundle_embeddings = {}

    for key_network, weight in sd.items():
        if diffusers_weight_map:
            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
            network_part = network_name + '.' + network_weight
        else:
            key_network_without_network_parts, _, network_part = key_network.partition(".")

        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})
            if vec_name.split('.')[0] == 'string_to_param':
                _, k2 = vec_name.split('.', 1)
                emb_dict['string_to_param'] = {k2: weight}
            else:
                emb_dict[vec_name] = weight
            bundle_embeddings[emb_name] = emb_dict

        if diffusers_weight_map:
            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
        else:
            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)

        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # kohya_ss OFT module
        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # KohakuBlueLeaf OFT module
        if sd_module is None and "oft_diag" in key:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
        embedding.loaded = None
        embedding.shorthash = BundledTIHash(name)
        embeddings[emb_name] = embedding

    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net

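# Evicts cached networks, oldest first, until the count is within shared.opts.lora_in_memory_limit.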
def purge_networks_from_memory():
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
        name = next(iter(networks_in_memory))
        networks_in_memory.pop(name, None)

    devices.torch_gc()

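# Loads the networks requested by name (reusing cached copies where possible), assigns their
# multipliers, and registers any bundled embeddings with the embedding database.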
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    emb_db = sd_hijack.model_hijack.embedding_db
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded:
                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()

    unavailable_networks = []
    for name in names:
        if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
            unavailable_networks.append(name)
        elif available_network_aliases.get(name) is None:
            unavailable_networks.append(name)

    if unavailable_networks:
        update_available_networks_by_names(unavailable_networks)

    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)

                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
                logger.warning(
                    f'Skip bundle embedding: "{emb_name}"'
                    ' as it was already loaded from embeddings folder'
                )
                continue

            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                emb_db.skipped_embeddings[name] = embedding

    if failed_to_load_networks:
        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
        sd_hijack.model_hijack.comments.append(lora_not_found_message)
        if shared.opts.lora_not_found_warning_console:
            print(f'\n{lora_not_found_message}\n')
        if shared.opts.lora_not_found_gradio_warning:
            gr.Warning(lora_not_found_message)

    purge_networks_from_memory()

def allowed_layer_without_weight(layer):
    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
        return True

    return False


def store_weights_backup(weight):
    if weight is None:
        return None

    return weight.to(devices.cpu, copy=True)


def restore_weights_backup(obj, field, weight):
    if weight is None:
        setattr(obj, field, None)
        return

    getattr(obj, field).copy_(weight)

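# Restores a layer's original weight and bias from the backups stored on the module, handling
# MultiheadAttention's separate in_proj/out_proj parameters.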
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    if weights_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
        else:
            restore_weights_backup(self, 'weight', weights_backup)

    if isinstance(self, torch.nn.MultiheadAttention):
        restore_weights_backup(self.out_proj, 'bias', bias_backup)
    else:
        restore_weights_backup(self, 'bias', bias_backup)

def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != () and not allowed_layer_without_weight(self):
            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
        else:
            weights_backup = store_weights_backup(self.weight)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and wanted_names != ():
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = store_weights_backup(self.out_proj.bias)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = store_weights_backup(self.bias)
        else:
            bias_backup = None

        # Unlike weight which always has value, some modules don't have bias.
        # Only report if bias is not None and current bias are not unchanged.
        if bias_backup is not None and current_names != ():
            raise RuntimeError("no backup bias found and current bias are not unchanged")

        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
                try:
                    with torch.no_grad():
                        if getattr(self, 'fp16_weight', None) is None:
                            weight = self.weight
                            bias = self.bias
                        else:
                            weight = self.fp16_weight.clone().to(self.weight.device)
                            bias = getattr(self, 'fp16_bias', None)
                            if bias is not None:
                                bias = bias.clone().to(self.bias.device)
                        updown, ex_bias = module.calc_updown(weight)

                        if len(weight.shape) == 4 and weight.shape[1] == 9:
                            # inpainting model. zero pad updown to make channel[1] 4 to 9
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                            else:
                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        # Send "real" orig_weight into MHA's lora module
                        qw, kw, vw = self.in_proj_weight.chunk(3, 0)
                        updown_q, _ = module_q.calc_updown(qw)
                        updown_k, _ = module_k.calc_updown(kw)
                        updown_v, _ = module_v.calc_updown(vw)
                        del qw, kw, vw
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out

                        if ex_bias is not None:
                            if self.out_proj.bias is None:
                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.out_proj.bias += ex_bias
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
                try:
                    with torch.no_grad():
                        # Send "real" orig_weight into MHA's lora module
                        qw, kw, vw = self.weight.chunk(3, 0)
                        updown_q, _ = module_q.calc_updown(qw)
                        updown_k, _ = module_k.calc_updown(kw)
                        updown_v, _ = module_v.calc_updown(vw)
                        del qw, kw, vw
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        self.weight += updown_qkv
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names

def network_forward(org_module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.
    """

    if len(loaded_networks) == 0:
        return original_forward(org_module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(org_module)
    network_reset_cached_weight(org_module)

    y = original_forward(org_module, input)

    network_layer_name = getattr(org_module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
            continue

        y = module.forward(input, y)

    return y


def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None

def network_Linear_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Linear_forward)

    network_apply_weights(self)

    return originals.Linear_forward(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Linear_load_state_dict(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Conv2d_forward)

    network_apply_weights(self)

    return originals.Conv2d_forward(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Conv2d_load_state_dict(self, *args, **kwargs)


def network_GroupNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.GroupNorm_forward)

    network_apply_weights(self)

    return originals.GroupNorm_forward(self, input)


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)


def network_LayerNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.LayerNorm_forward)

    network_apply_weights(self)

    return originals.LayerNorm_forward(self, input)


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    network_apply_weights(self)

    return originals.MultiheadAttention_forward(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)

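# Scans the Lora directories for network files and registers them in available_networks and the
# alias lookup tables; when names is given, only networks with those names are (re)loaded.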
def process_network_files(names: list[str] | None = None):
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        # if names is provided, only load networks with names in the list
        if names and name not in names:
            continue

        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


def update_available_networks_by_names(names: list[str]):
    process_network_files(names)


def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    process_network_files()

re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)

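# Module-level state; originals and extra_network_lora are expected to be assigned by the
# extension's startup code before the hooks above are used.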
originals: lora_patches.LoraPatches = None

extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()