import gradio as gr
import logging
import os
import re

import lora_patches
import network
import network_lora
import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
import network_oft
import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion

from lora_logger import logger


module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
    network_oft.ModuleTypeOFT(),
]


re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}
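

# Maps a diffusers-style LoRA key to the compvis (ldm) module name used in
# network_layer_mapping. For example, "lora_unet_down_blocks_0_attentions_0_proj_in"
# maps to "diffusion_model_input_blocks_1_1_proj_in"; keys that match no pattern
# are returned unchanged.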
def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
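

# Walks the text encoder(s) and the UNet, stamps every torch module with a
# network_layer_name attribute (dots replaced by underscores; SDXL embedders
# additionally prefixed with their index), and stores the resulting
# name -> module lookup on the model as network_layer_mapping.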
def assign_network_names_to_compvis_modules(sd_model):
    network_layer_mapping = {}

    if shared.sd_model.is_sdxl:
        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                network_name = f'{i}_{name.replace(".", "_")}'
                network_layer_mapping[network_name] = module
                module.network_layer_name = network_name
    else:
        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    for name, module in shared.sd_model.model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    sd_model.network_layer_mapping = network_layer_mapping
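

# Reads a network's state dict from disk, matches each key to a module of the
# currently loaded model, and returns a network.Network holding one created
# module per matched layer plus any bundled textual-inversion embeddings.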
def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    matched_networks = {}
    bundle_embeddings = {}

    for key_network, weight in sd.items():
        key_network_without_network_parts, _, network_part = key_network.partition(".")

        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})
            if vec_name.split('.')[0] == 'string_to_param':
                _, k2 = vec_name.split('.', 1)
                emb_dict['string_to_param'] = {k2: weight}
            else:
                emb_dict[vec_name] = weight
            bundle_embeddings[emb_name] = emb_dict
            # bundled embedding tensors are not network weights; skip the key matching below
            continue

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # kohya_ss OFT module
        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # KohakuBlueLeaf OFT module
        if sd_module is None and "oft_diag" in key:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
        embedding.loaded = None
        embeddings[emb_name] = embedding

    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net
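

# Evicts the oldest cached networks (dict insertion order) until the cache is
# within the lora_in_memory_limit option, then asks torch to release memory.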
def purge_networks_from_memory():
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
        name = next(iter(networks_in_memory))
        networks_in_memory.pop(name, None)

    devices.torch_gc()
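

# Loads (or reuses from cache) the networks in `names` and makes them the active
# set in loaded_networks. Multiplier lists are matched to `names` by index and
# default to 1.0. A hypothetical call, for illustration:
#   load_networks(["myLora"], te_multipliers=[0.8], unet_multipliers=[0.8])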
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    emb_db = sd_hijack.model_hijack.embedding_db
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded:
                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()

    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)

                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
                logger.warning(
                    f'Skip bundle embedding: "{emb_name}"'
                    ' as it was already loaded from embeddings folder'
                )
                continue

            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                emb_db.skipped_embeddings[emb_name] = embedding

    if failed_to_load_networks:
        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
        sd_hijack.model_hijack.comments.append(lora_not_found_message)
        if shared.opts.lora_not_found_warning_console:
            print(f'\n{lora_not_found_message}\n')
        if shared.opts.lora_not_found_gradio_warning:
            gr.Warning(lora_not_found_message)

    purge_networks_from_memory()
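

# Restores a layer's original weights and bias from the CPU-side backups made by
# network_apply_weights. MultiheadAttention keeps a (in_proj, out_proj) weight
# tuple as its backup; all other layer types keep a single tensor.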
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    if weights_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.in_proj_weight.copy_(weights_backup[0])
            self.out_proj.weight.copy_(weights_backup[1])
        else:
            self.weight.copy_(weights_backup)

    if bias_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias.copy_(bias_backup)
        else:
            self.bias.copy_(bias_backup)
    else:
        if isinstance(self, torch.nn.MultiheadAttention):
            self.out_proj.bias = None
        else:
            self.bias = None


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None
        self.network_bias_backup = bias_backup
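
    # With backups secured, (re)apply the wanted set: restore the original
    # weights, then fold each active network's delta into them in place
    # (weight <- weight + updown, bias <- bias + ex_bias).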
    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                try:
                    with torch.no_grad():
                        if getattr(self, 'fp16_weight', None) is None:
                            weight = self.weight
                            bias = self.bias
                        else:
                            weight = self.fp16_weight.clone().to(self.weight.device)
                            bias = getattr(self, 'fp16_bias', None)
                            if bias is not None:
                                bias = bias.clone().to(self.bias.device)

                        updown, ex_bias = module.calc_updown(weight)

                        if len(weight.shape) == 4 and weight.shape[1] == 9:
                            # inpainting model: zero-pad updown from 4 to 9 input channels to match the weight
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))

                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                self.bias = torch.nn.Parameter(ex_bias.to(self.weight.dtype))
                            else:
                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        # MultiheadAttention keeps q, k and v stacked in a single in_proj weight
                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out

                        if ex_bias is not None:
                            if self.out_proj.bias is None:
                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.out_proj.bias += ex_bias

                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names


def network_forward(org_module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in significant performance degradation.
    """

    if len(loaded_networks) == 0:
        return original_forward(org_module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(org_module)
    network_reset_cached_weight(org_module)

    y = original_forward(org_module, input)

    network_layer_name = getattr(org_module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
            continue

        y = module.forward(input, y)

    return y
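

# Forgets cached backups and the record of which networks are applied, so the
# next forward pass re-creates backups and re-applies networks from scratch.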
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None
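

# The hooks below are installed over the torch.nn layer methods via lora_patches.
# Each forward hook folds the active networks into the layer's weights (or, with
# the lora_functional option, applies them during the forward pass itself) and
# then delegates to the original implementation; each load_state_dict hook drops
# cached backups so freshly loaded weights are not overwritten by stale ones.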
def network_Linear_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Linear_forward)

    network_apply_weights(self)

    return originals.Linear_forward(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Linear_load_state_dict(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Conv2d_forward)

    network_apply_weights(self)

    return originals.Conv2d_forward(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Conv2d_load_state_dict(self, *args, **kwargs)


def network_GroupNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.GroupNorm_forward)

    network_apply_weights(self)

    return originals.GroupNorm_forward(self, input)


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)


def network_LayerNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.LayerNorm_forward)

    network_apply_weights(self)

    return originals.LayerNorm_forward(self, input)


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    network_apply_weights(self)

    return originals.MultiheadAttention_forward(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
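

# Rescans the LoRA directories and rebuilds the registries of available networks
# and aliases; an alias shared by more than one file is marked forbidden so that
# only the exact file name can select it.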
def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
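

# Compatibility with the additional-networks (AddNet) extension's infotext: if
# that extension is not active, its "AddNet Model N" / "AddNet Weight A N"
# parameters are translated into <lora:name:multiplier> syntax and appended to
# the prompt, e.g. "<lora:myLora:0.8>" (hypothetical name, for illustration).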
def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)
originals: lora_patches.LoraPatches = None
extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()