# networks.py

import logging
import os
import re

import lora_patches
import network
import network_lora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm

import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding

from lora_logger import logger

module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
]

re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}

def convert_diffusers_name_to_compvis(key, is_sd2):
    """Translate a diffusers-style LoRA key prefix into the compvis module name used by the loaded model; returns the key unchanged if no pattern matches."""

    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
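
# Illustrative mappings for the conversion above (sample keys, not taken from a real LoRA file):
#   "lora_unet_down_blocks_0_attentions_0_proj_in" -> "diffusion_model_input_blocks_1_1_proj_in"
#   "lora_unet_mid_block_resnets_0_conv1"          -> "diffusion_model_middle_block_0_in_layers_2"
# Keys that match none of the patterns are returned unchanged.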

def assign_network_names_to_compvis_modules(sd_model):
    """Build a mapping from underscore-joined module names to the model's torch modules and tag each module with its network_layer_name."""

    network_layer_mapping = {}

    if shared.sd_model.is_sdxl:
        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                network_name = f'{i}_{name.replace(".", "_")}'
                network_layer_mapping[network_name] = module
                module.network_layer_name = network_name
    else:
        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    for name, module in shared.sd_model.model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    sd_model.network_layer_mapping = network_layer_mapping
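
# Note (added): for an SD1-style text encoder this produces keys such as
# "transformer_text_model_encoder_layers_0_self_attn_q_proj" (the exact names depend on the
# wrapped model's own module hierarchy); SDXL embedders get a "0_"/"1_" index prefix, and
# UNet modules are keyed relative to shared.sd_model.model, e.g. "diffusion_model_...".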

def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    matched_networks = {}
    bundle_embeddings = {}

    for key_network, weight in sd.items():
        key_network_without_network_parts, network_part = key_network.split(".", 1)

        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})
            if vec_name.split('.')[0] == 'string_to_param':
                _, k2 = vec_name.split('.', 1)
                emb_dict['string_to_param'] = {k2: weight}
            else:
                emb_dict[vec_name] = weight
            bundle_embeddings[emb_name] = emb_dict

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)

            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            shape = vec.shape[-1]
            vectors = vec.shape[0]
        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
            vectors = data['clip_g'].shape[0]
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            shape = vec.shape[-1]
            vectors = vec.shape[0]
        else:
            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")

        embedding = Embedding(vec, emb_name)
        embedding.vectors = vectors
        embedding.shape = shape
        embedding.loaded = None
        embeddings[emb_name] = embedding

    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net
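
# Note (added): the returned network.Network carries .modules (compvis layer name -> per-layer
# module created by one of module_types) and .bundle_embeddings (textual inversion embeddings
# shipped inside the LoRA file under "bundle_emb" keys).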

def purge_networks_from_memory():
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
        name = next(iter(networks_in_memory))
        networks_in_memory.pop(name, None)

    devices.torch_gc()

def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    emb_db = sd_hijack.model_hijack.embedding_db
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded:
                embedding.loaded = None
                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()

    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

        networks_on_disk = [available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)

                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
                logger.warning(
                    f'Skip bundle embedding: "{emb_name}"'
                    ' as it was already loaded from embeddings folder'
                )
                continue

            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                emb_db.skipped_embeddings[name] = embedding

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))

    purge_networks_from_memory()
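
# Hedged usage sketch (hypothetical names, not part of the original file): loading two LoRAs
# with separate text-encoder and UNet strengths; the multiplier lists line up with `names`:
#
#   load_networks(["styleA", "detailB"], te_multipliers=[0.8, 1.0], unet_multipliers=[0.8, 0.5])
#
# Each name is resolved through available_network_aliases, so either the file's stem name or
# its alias works; anything that cannot be resolved is reported via model_hijack.comments.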

def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    if weights_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.in_proj_weight.copy_(weights_backup[0])
            self.out_proj.weight.copy_(weights_backup[1])
        else:
            self.weight.copy_(weights_backup)

    if bias_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias.copy_(bias_backup)
        else:
            self.bias.copy_(bias_backup)
    else:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias = None
        else:
            self.bias = None

def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None
        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                try:
                    with torch.no_grad():
                        updown, ex_bias = module.calc_updown(self.weight)

                        if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                            # inpainting model. zero pad updown to make channel[1] 4 to 9
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight += updown
                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                self.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.bias += ex_bias
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out

                        if ex_bias is not None:
                            if self.out_proj.bias is None:
                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.out_proj.bias += ex_bias
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names
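
# In effect, when the wanted set of networks changes, the layer weight is rebuilt as
# W' = W_backup + sum(updown_i over loaded networks) rather than edited incrementally,
# with a zero-pad special case for 9-channel inpainting UNet input convolutions.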

def network_forward(module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.
    """

    if len(loaded_networks) == 0:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)

    y = original_forward(module, input)

    network_layer_name = getattr(module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
            continue

        y = module.forward(input, y)

    return y

def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None

def network_Linear_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Linear_forward)

    network_apply_weights(self)

    return originals.Linear_forward(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Linear_load_state_dict(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Conv2d_forward)

    network_apply_weights(self)

    return originals.Conv2d_forward(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Conv2d_load_state_dict(self, *args, **kwargs)


def network_GroupNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.GroupNorm_forward)

    network_apply_weights(self)

    return originals.GroupNorm_forward(self, input)


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)


def network_LayerNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.LayerNorm_forward)

    network_apply_weights(self)

    return originals.LayerNorm_forward(self, input)


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    network_apply_weights(self)

    return originals.MultiheadAttention_forward(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
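
# Note (added): these wrappers are not invoked from this file; the companion lora_patches
# module is expected to install them in place of the corresponding torch.nn layer methods
# and to keep the unpatched versions in the module-level `originals` object defined below.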

def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")

def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)
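
# Illustrative example (hypothetical values): given pasted parameters such as
#   {"AddNet Module 1": "LoRA", "AddNet Model 1": "myLora(a1b2c3d4)", "AddNet Weight A 1": "0.8"}
# the loop above appends "<lora:myLora:0.8>" to params["Prompt"], letting the built-in
# extra-networks syntax take over from the AddNet extension's fields.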

originals: lora_patches.LoraPatches = None

extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}

available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()