network.py

from __future__ import annotations

import os
from collections import namedtuple
import enum

import torch.nn as nn
import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
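
# NetworkWeights bundles everything needed to build one module patch: the key as it
# appears in the network (LoRA) file, the corresponding key in the SD model, the dict
# of tensors stored under that key, and the torch module being patched.
# metadata_tags_order sorts safetensors metadata for display; tags not listed sort last (999).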


class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)
            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}

            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> NetworkModule | None:
        return None
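
# Concrete subclasses of ModuleType (the handlers in the sibling network_*.py files,
# e.g. for LoRA, LoHa, LoKr, IA3) inspect the NetworkWeights and return a NetworkModule
# when they recognize its keys, or None so the next handler can try.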


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
            s = self.sd_module.weight.shape
            self.shape = (s[0] // 3, s[1])
        elif hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        elif isinstance(self.sd_module, nn.MultiheadAttention):
            # For now, only self-attention uses PyTorch's MHA,
            # so assume all q/k/v/o projections have the same shape
            self.shape = self.sd_module.out_proj.weight.shape
        else:
            self.shape = None

        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None
        self.dora_scale = weights.w.get("dora_scale", None)
        self.dora_norm_dims = len(self.shape) - 1

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0
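
    # Example of the scale convention above (illustrative numbers): a LoRA trained with
    # alpha=8 at rank dim=16 gives calc_scale() == 8 / 16 == 0.5, so the learned delta is
    # halved before the per-network multiplier is applied; an explicit "scale" tensor, if
    # present, overrides this entirely, and with neither alpha nor dim the delta is used as-is.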

    def apply_weight_decompose(self, updown, orig_weight):
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )

        final_updown = dora_merged - orig_weight
        return final_updown
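
    # DoRA, in sketch form: with W the original weight and dW the low-rank delta, each
    # column of (W + dW) is rescaled to the learned magnitude,
    #     W' = (W + dW) * dora_scale / ||W + dW||_col,
    # and the difference W' - W is returned so callers can keep treating the result as a delta.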

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        updown = updown * self.calc_scale()

        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)

        return updown * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
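

# --- Illustrative sketch only, not part of the upstream file: a minimal subclass showing
# how the calc_updown()/finalize_updown() contract is typically used. The class and the
# "lora_up.weight"/"lora_down.weight" key names are assumptions for illustration; the real
# handlers live in the network_*.py modules alongside this file.
class _ExampleLoraModule(NetworkModule):
    def __init__(self, net: Network, weights: NetworkWeights):
        super().__init__(net, weights)
        self.up_weight = weights.w["lora_up.weight"]      # assumed shape (out_features, rank)
        self.down_weight = weights.w["lora_down.weight"]  # assumed shape (rank, in_features)
        self.dim = self.down_weight.shape[0]              # rank, consumed by calc_scale()

    def calc_updown(self, orig_weight):
        up = self.up_weight.to(orig_weight.device, dtype=orig_weight.dtype)
        down = self.down_weight.to(orig_weight.device, dtype=orig_weight.dtype)
        updown = up @ down  # low-rank delta with the same shape as orig_weight
        output_shape = [up.size(0), down.size(1)]
        return self.finalize_updown(updown, orig_weight, output_shape)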