network.py

from __future__ import annotations
import os
from collections import namedtuple
import enum

import torch.nn as nn
import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
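
# Which base Stable Diffusion family a network was trained for, inferred from
# its training metadata (see NetworkOnDisk.detect_version).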
class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4
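

# A network (e.g. LoRA) file found on disk: reads safetensors metadata,
# tracks hashes, and resolves the alias used to reference it in prompts.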
class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)
            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)
        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()
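
    # Guess which base model the network was trained against from its metadata.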
    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown
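
    # Store the full sha256 plus a 12-character short hash, and register the
    # short hash in the global lookup used to resolve networks by hash.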
    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self
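
    # Compute the hash from the file only if it has not been set already.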
    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
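
    # Name to use in prompts: the filename if the user prefers filenames or the
    # alias is forbidden, otherwise the alias from the metadata.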
    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias
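

# A network that has been loaded for the current generation, with separate
# strength multipliers for the text encoder and the UNet.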
class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to the prompt; can be either the name or an alias"""
class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None
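

# Base class for a single patched layer: records the target torch module, the
# functional op and kwargs used to apply it, and the raw weights needed to
# compute the weight delta.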
class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        elif isinstance(self.sd_module, nn.MultiheadAttention):
            # For now, only self-attention uses PyTorch's MHA,
            # so assume all q/k/v/o projections have the same shape.
            self.shape = self.sd_module.out_proj.weight.shape
        else:
            self.shape = None

        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None
        self.dora_scale = weights.w.get("dora_scale", None)
        self.dora_norm_dims = len(self.shape) - 1
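
    # Strength used for this layer: text-encoder keys get te_multiplier,
    # everything else gets unet_multiplier.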
    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier
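
    # Scale for the computed delta: an explicit 'scale' wins, then alpha / dim,
    # otherwise 1.0.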
    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0
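
    # DoRA weight decomposition: normalise each column of (updown + orig_weight),
    # rescale it by the learned dora_scale, and return the difference from the
    # original weight.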
    def apply_weight_decompose(self, updown, orig_weight):
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )
        final_updown = dora_merged - orig_weight
        return final_updown
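
    # Apply the stored bias, reshape the delta to the target weight's shape,
    # apply DoRA if present, and scale by calc_scale() and the multiplier.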
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)

        return updown * self.calc_scale() * self.multiplier(), ex_bias
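
    # Compute the raw (delta_weight, extra_bias) for the target weight;
    # implemented by algorithm-specific subclasses.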
    def calc_updown(self, target):
        raise NotImplementedError()
    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
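

# --- Illustrative sketch, not part of the original file ----------------------
# A hypothetical "full matrix" module type showing how subclasses are expected
# to plug into ModuleType / NetworkModule above. The class names and the "diff"
# weight key are assumptions for illustration only; the real implementations
# live in the extension's network_* modules.
class ExampleModuleTypeFull(ModuleType):
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        # Claim the weights only when they contain a full delta matrix.
        if "diff" in weights.w:
            return ExampleNetworkModuleFull(net, weights)

        return None


class ExampleNetworkModuleFull(NetworkModule):
    def __init__(self, net: Network, weights: NetworkWeights):
        super().__init__(net, weights)
        self.weight = weights.w["diff"]

    def calc_updown(self, orig_weight):
        # Move the stored delta onto the target weight's device/dtype and let
        # finalize_updown handle bias, reshaping, DoRA and scaling.
        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        output_shape = self.weight.shape

        return self.finalize_updown(updown, orig_weight, output_shape)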