network.py

from __future__ import annotations
import os
from collections import namedtuple
import enum

import torch.nn as nn
import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4

class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)
            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}

            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias

class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to the prompt - can be either the name or an alias"""

class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None

class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        # functional op matching the wrapped layer, used by forward() to re-run
        # the layer with the patched weight
        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim  # standard LoRA scaling: alpha divided by rank

        return 1.0

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        return updown * self.calc_scale() * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
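

# --- Illustrative sketch (not part of the original file) ---
# A minimal example of how a concrete module type might plug into the classes
# above: a plain low-rank (LoRA-style) NetworkModule for nn.Linear layers.
# The key names "lora_down.weight" / "lora_up.weight" and the Example* class
# names are assumptions for illustration only; the extension's real
# implementations live in its network_* modules.
class ExampleModuleTypeLora(ModuleType):
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        if "lora_down.weight" in weights.w and "lora_up.weight" in weights.w:
            return ExampleNetworkModuleLora(net, weights)

        return None


class ExampleNetworkModuleLora(NetworkModule):
    def __init__(self, net: Network, weights: NetworkWeights):
        super().__init__(net, weights)

        self.down = weights.w["lora_down.weight"]  # [rank, in_features]
        self.up = weights.w["lora_up.weight"]      # [out_features, rank]
        self.dim = self.down.shape[0]              # rank, used by calc_scale() as alpha / dim

    def calc_updown(self, target):
        # delta_W = up @ down, computed on the target weight's device/dtype;
        # scaling by calc_scale()/multiplier() and reshaping are delegated to finalize_updown()
        down = self.down.to(target.device, dtype=target.dtype)
        up = self.up.to(target.device, dtype=target.dtype)

        updown = up @ down
        output_shape = [up.shape[0], down.shape[1]]

        return self.finalize_updown(updown, target, output_shape)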