networks.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. import logging
  2. import os
  3. import re
  4. import lora_patches
  5. import network
  6. import network_lora
  7. import network_hada
  8. import network_ia3
  9. import network_lokr
  10. import network_full
  11. import network_norm
  12. import torch
  13. from typing import Union
  14. from modules import shared, devices, sd_models, errors, scripts, sd_hijack
  15. module_types = [
  16. network_lora.ModuleTypeLora(),
  17. network_hada.ModuleTypeHada(),
  18. network_ia3.ModuleTypeIa3(),
  19. network_lokr.ModuleTypeLokr(),
  20. network_full.ModuleTypeFull(),
  21. network_norm.ModuleTypeNorm(),
  22. ]
  23. re_digits = re.compile(r"\d+")
  24. re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
  25. re_compiled = {}
  26. suffix_conversion = {
  27. "attentions": {},
  28. "resnets": {
  29. "conv1": "in_layers_2",
  30. "conv2": "out_layers_3",
  31. "norm1": "in_layers_0",
  32. "norm2": "out_layers_0",
  33. "time_emb_proj": "emb_layers_1",
  34. "conv_shortcut": "skip_connection",
  35. }
  36. }
  37. def convert_diffusers_name_to_compvis(key, is_sd2):
  38. def match(match_list, regex_text):
  39. regex = re_compiled.get(regex_text)
  40. if regex is None:
  41. regex = re.compile(regex_text)
  42. re_compiled[regex_text] = regex
  43. r = re.match(regex, key)
  44. if not r:
  45. return False
  46. match_list.clear()
  47. match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
  48. return True
  49. m = []
  50. if match(m, r"lora_unet_conv_in(.*)"):
  51. return f'diffusion_model_input_blocks_0_0{m[0]}'
  52. if match(m, r"lora_unet_conv_out(.*)"):
  53. return f'diffusion_model_out_2{m[0]}'
  54. if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
  55. return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
  56. if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
  57. suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
  58. return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
  59. if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
  60. suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
  61. return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
  62. if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
  63. suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
  64. return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
  65. if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
  66. return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
  67. if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
  68. return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
  69. if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
  70. if is_sd2:
  71. if 'mlp_fc1' in m[1]:
  72. return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
  73. elif 'mlp_fc2' in m[1]:
  74. return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
  75. else:
  76. return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
  77. return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
  78. if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
  79. if 'mlp_fc1' in m[1]:
  80. return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
  81. elif 'mlp_fc2' in m[1]:
  82. return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
  83. else:
  84. return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
  85. return key
  86. def assign_network_names_to_compvis_modules(sd_model):
  87. network_layer_mapping = {}
  88. if shared.sd_model.is_sdxl:
  89. for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
  90. if not hasattr(embedder, 'wrapped'):
  91. continue
  92. for name, module in embedder.wrapped.named_modules():
  93. network_name = f'{i}_{name.replace(".", "_")}'
  94. network_layer_mapping[network_name] = module
  95. module.network_layer_name = network_name
  96. else:
  97. for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
  98. network_name = name.replace(".", "_")
  99. network_layer_mapping[network_name] = module
  100. module.network_layer_name = network_name
  101. for name, module in shared.sd_model.model.named_modules():
  102. network_name = name.replace(".", "_")
  103. network_layer_mapping[network_name] = module
  104. module.network_layer_name = network_name
  105. sd_model.network_layer_mapping = network_layer_mapping
  106. def load_network(name, network_on_disk):
  107. net = network.Network(name, network_on_disk)
  108. net.mtime = os.path.getmtime(network_on_disk.filename)
  109. sd = sd_models.read_state_dict(network_on_disk.filename)
  110. # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
  111. if not hasattr(shared.sd_model, 'network_layer_mapping'):
  112. assign_network_names_to_compvis_modules(shared.sd_model)
  113. keys_failed_to_match = {}
  114. is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
  115. matched_networks = {}
  116. for key_network, weight in sd.items():
  117. key_network_without_network_parts, network_part = key_network.split(".", 1)
  118. key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
  119. sd_module = shared.sd_model.network_layer_mapping.get(key, None)
  120. if sd_module is None:
  121. m = re_x_proj.match(key)
  122. if m:
  123. sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
  124. # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
  125. if sd_module is None and "lora_unet" in key_network_without_network_parts:
  126. key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
  127. sd_module = shared.sd_model.network_layer_mapping.get(key, None)
  128. elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
  129. key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
  130. sd_module = shared.sd_model.network_layer_mapping.get(key, None)
  131. # some SD1 Loras also have correct compvis keys
  132. if sd_module is None:
  133. key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
  134. sd_module = shared.sd_model.network_layer_mapping.get(key, None)
  135. if sd_module is None:
  136. keys_failed_to_match[key_network] = key
  137. continue
  138. if key not in matched_networks:
  139. matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
  140. matched_networks[key].w[network_part] = weight
  141. for key, weights in matched_networks.items():
  142. net_module = None
  143. for nettype in module_types:
  144. net_module = nettype.create_module(net, weights)
  145. if net_module is not None:
  146. break
  147. if net_module is None:
  148. raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
  149. net.modules[key] = net_module
  150. if keys_failed_to_match:
  151. logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
  152. return net
  153. def purge_networks_from_memory():
  154. while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
  155. name = next(iter(networks_in_memory))
  156. networks_in_memory.pop(name, None)
  157. devices.torch_gc()
  158. def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
  159. already_loaded = {}
  160. for net in loaded_networks:
  161. if net.name in names:
  162. already_loaded[net.name] = net
  163. loaded_networks.clear()
  164. networks_on_disk = [available_network_aliases.get(name, None) for name in names]
  165. if any(x is None for x in networks_on_disk):
  166. list_available_networks()
  167. networks_on_disk = [available_network_aliases.get(name, None) for name in names]
  168. failed_to_load_networks = []
  169. for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
  170. net = already_loaded.get(name, None)
  171. if network_on_disk is not None:
  172. if net is None:
  173. net = networks_in_memory.get(name)
  174. if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
  175. try:
  176. net = load_network(name, network_on_disk)
  177. networks_in_memory.pop(name, None)
  178. networks_in_memory[name] = net
  179. except Exception as e:
  180. errors.display(e, f"loading network {network_on_disk.filename}")
  181. continue
  182. net.mentioned_name = name
  183. network_on_disk.read_hash()
  184. if net is None:
  185. failed_to_load_networks.append(name)
  186. logging.info(f"Couldn't find network with name {name}")
  187. continue
  188. net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
  189. net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
  190. net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
  191. loaded_networks.append(net)
  192. if failed_to_load_networks:
  193. sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
  194. purge_networks_from_memory()
  195. def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
  196. weights_backup = getattr(self, "network_weights_backup", None)
  197. bias_backup = getattr(self, "network_bias_backup", None)
  198. if weights_backup is None and bias_backup is None:
  199. return
  200. if weights_backup is not None:
  201. if isinstance(self, torch.nn.MultiheadAttention):
  202. self.in_proj_weight.copy_(weights_backup[0])
  203. self.out_proj.weight.copy_(weights_backup[1])
  204. else:
  205. self.weight.copy_(weights_backup)
  206. if bias_backup is not None:
  207. if isinstance(self, torch.nn.MultiheadAttention):
  208. self.out_proj.bias.copy_(bias_backup)
  209. else:
  210. self.bias.copy_(bias_backup)
  211. else:
  212. if isinstance(self, torch.nn.MultiheadAttention):
  213. self.out_proj.bias = None
  214. else:
  215. self.bias = None
  216. def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
  217. """
  218. Applies the currently selected set of networks to the weights of torch layer self.
  219. If weights already have this particular set of networks applied, does nothing.
  220. If not, restores orginal weights from backup and alters weights according to networks.
  221. """
  222. network_layer_name = getattr(self, 'network_layer_name', None)
  223. if network_layer_name is None:
  224. return
  225. current_names = getattr(self, "network_current_names", ())
  226. wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
  227. weights_backup = getattr(self, "network_weights_backup", None)
  228. if weights_backup is None and wanted_names != ():
  229. if current_names != ():
  230. raise RuntimeError("no backup weights found and current weights are not unchanged")
  231. if isinstance(self, torch.nn.MultiheadAttention):
  232. weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
  233. else:
  234. weights_backup = self.weight.to(devices.cpu, copy=True)
  235. self.network_weights_backup = weights_backup
  236. bias_backup = getattr(self, "network_bias_backup", None)
  237. if bias_backup is None:
  238. if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
  239. bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
  240. elif getattr(self, 'bias', None) is not None:
  241. bias_backup = self.bias.to(devices.cpu, copy=True)
  242. else:
  243. bias_backup = None
  244. self.network_bias_backup = bias_backup
  245. if current_names != wanted_names:
  246. network_restore_weights_from_backup(self)
  247. for net in loaded_networks:
  248. module = net.modules.get(network_layer_name, None)
  249. if module is not None and hasattr(self, 'weight'):
  250. try:
  251. with torch.no_grad():
  252. updown, ex_bias = module.calc_updown(self.weight)
  253. if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
  254. # inpainting model. zero pad updown to make channel[1] 4 to 9
  255. updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
  256. self.weight += updown
  257. if ex_bias is not None and hasattr(self, 'bias'):
  258. if self.bias is None:
  259. self.bias = torch.nn.Parameter(ex_bias)
  260. else:
  261. self.bias += ex_bias
  262. except RuntimeError as e:
  263. logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
  264. extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
  265. continue
  266. module_q = net.modules.get(network_layer_name + "_q_proj", None)
  267. module_k = net.modules.get(network_layer_name + "_k_proj", None)
  268. module_v = net.modules.get(network_layer_name + "_v_proj", None)
  269. module_out = net.modules.get(network_layer_name + "_out_proj", None)
  270. if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
  271. try:
  272. with torch.no_grad():
  273. updown_q, _ = module_q.calc_updown(self.in_proj_weight)
  274. updown_k, _ = module_k.calc_updown(self.in_proj_weight)
  275. updown_v, _ = module_v.calc_updown(self.in_proj_weight)
  276. updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
  277. updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
  278. self.in_proj_weight += updown_qkv
  279. self.out_proj.weight += updown_out
  280. if ex_bias is not None:
  281. if self.out_proj.bias is None:
  282. self.out_proj.bias = torch.nn.Parameter(ex_bias)
  283. else:
  284. self.out_proj.bias += ex_bias
  285. except RuntimeError as e:
  286. logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
  287. extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
  288. continue
  289. if module is None:
  290. continue
  291. logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
  292. extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
  293. self.network_current_names = wanted_names
  294. def network_forward(module, input, original_forward):
  295. """
  296. Old way of applying Lora by executing operations during layer's forward.
  297. Stacking many loras this way results in big performance degradation.
  298. """
  299. if len(loaded_networks) == 0:
  300. return original_forward(module, input)
  301. input = devices.cond_cast_unet(input)
  302. network_restore_weights_from_backup(module)
  303. network_reset_cached_weight(module)
  304. y = original_forward(module, input)
  305. network_layer_name = getattr(module, 'network_layer_name', None)
  306. for lora in loaded_networks:
  307. module = lora.modules.get(network_layer_name, None)
  308. if module is None:
  309. continue
  310. y = module.forward(input, y)
  311. return y
  312. def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
  313. self.network_current_names = ()
  314. self.network_weights_backup = None
  315. def network_Linear_forward(self, input):
  316. if shared.opts.lora_functional:
  317. return network_forward(self, input, originals.Linear_forward)
  318. network_apply_weights(self)
  319. return originals.Linear_forward(self, input)
  320. def network_Linear_load_state_dict(self, *args, **kwargs):
  321. network_reset_cached_weight(self)
  322. return originals.Linear_load_state_dict(self, *args, **kwargs)
  323. def network_Conv2d_forward(self, input):
  324. if shared.opts.lora_functional:
  325. return network_forward(self, input, originals.Conv2d_forward)
  326. network_apply_weights(self)
  327. return originals.Conv2d_forward(self, input)
  328. def network_Conv2d_load_state_dict(self, *args, **kwargs):
  329. network_reset_cached_weight(self)
  330. return originals.Conv2d_load_state_dict(self, *args, **kwargs)
  331. def network_GroupNorm_forward(self, input):
  332. if shared.opts.lora_functional:
  333. return network_forward(self, input, originals.GroupNorm_forward)
  334. network_apply_weights(self)
  335. return originals.GroupNorm_forward(self, input)
  336. def network_GroupNorm_load_state_dict(self, *args, **kwargs):
  337. network_reset_cached_weight(self)
  338. return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
  339. def network_LayerNorm_forward(self, input):
  340. if shared.opts.lora_functional:
  341. return network_forward(self, input, originals.LayerNorm_forward)
  342. network_apply_weights(self)
  343. return originals.LayerNorm_forward(self, input)
  344. def network_LayerNorm_load_state_dict(self, *args, **kwargs):
  345. network_reset_cached_weight(self)
  346. return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
  347. def network_MultiheadAttention_forward(self, *args, **kwargs):
  348. network_apply_weights(self)
  349. return originals.MultiheadAttention_forward(self, *args, **kwargs)
  350. def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
  351. network_reset_cached_weight(self)
  352. return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
  353. def list_available_networks():
  354. available_networks.clear()
  355. available_network_aliases.clear()
  356. forbidden_network_aliases.clear()
  357. available_network_hash_lookup.clear()
  358. forbidden_network_aliases.update({"none": 1, "Addams": 1})
  359. os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
  360. candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
  361. candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
  362. for filename in candidates:
  363. if os.path.isdir(filename):
  364. continue
  365. name = os.path.splitext(os.path.basename(filename))[0]
  366. try:
  367. entry = network.NetworkOnDisk(name, filename)
  368. except OSError: # should catch FileNotFoundError and PermissionError etc.
  369. errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
  370. continue
  371. available_networks[name] = entry
  372. if entry.alias in available_network_aliases:
  373. forbidden_network_aliases[entry.alias.lower()] = 1
  374. available_network_aliases[name] = entry
  375. available_network_aliases[entry.alias] = entry
  376. re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
  377. def infotext_pasted(infotext, params):
  378. if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
  379. return # if the other extension is active, it will handle those fields, no need to do anything
  380. added = []
  381. for k in params:
  382. if not k.startswith("AddNet Model "):
  383. continue
  384. num = k[13:]
  385. if params.get("AddNet Module " + num) != "LoRA":
  386. continue
  387. name = params.get("AddNet Model " + num)
  388. if name is None:
  389. continue
  390. m = re_network_name.match(name)
  391. if m:
  392. name = m.group(1)
  393. multiplier = params.get("AddNet Weight A " + num, "1.0")
  394. added.append(f"<lora:{name}:{multiplier}>")
  395. if added:
  396. params["Prompt"] += "\n" + "".join(added)
  397. originals: lora_patches.LoraPatches = None
  398. extra_network_lora = None
  399. available_networks = {}
  400. available_network_aliases = {}
  401. loaded_networks = []
  402. networks_in_memory = {}
  403. available_network_hash_lookup = {}
  404. forbidden_network_aliases = {}
  405. list_available_networks()