network_lora.py

import torch

import lyco_helpers
import modules.models.sd3.mmdit
import network
from modules import devices


class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        # Kohya-style checkpoints store the two LoRA factors as lora_up.weight / lora_down.weight.
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)

        # PEFT-style checkpoints use lora_A (down) / lora_B (up); remap them to the keys above.
        if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
            w = weights.w.copy()
            weights.w.clear()
            weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})

            return NetworkModuleLora(net, weights)

        return None


class NetworkModuleLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.up_model = self.create_module(weights.w, "lora_up.weight")
        self.down_model = self.create_module(weights.w, "lora_down.weight")
        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)

        # LoRA rank, taken from the first dimension of the down-projection weight.
        self.dim = weights.w["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        # Build a torch module of the same kind as the targeted layer and load the LoRA factor into it.
        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif (is_conv and key == "lora_down.weight") or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif (is_conv and key == "lora_up.weight") or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        # Keep the factor on CPU in the shared dtype; it is only read when applying the network, never trained.
        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, orig_weight):
        up = self.up_model.weight.to(orig_weight.device)
        down = self.down_model.weight.to(orig_weight.device)

        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition: combine up and down through the mid tensor, which carries the conv kernel's spatial dims.
            mid = self.mid_model.weight.to(orig_weight.device)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)

        return self.finalize_updown(updown, orig_weight, output_shape)

    def forward(self, x, y):
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)

        # y is the original layer's output; add the scaled low-rank update up(down(x)).
        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
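
For reference, the reconstruction that calc_updown delegates to lyco_helpers.rebuild_conventional amounts to the standard LoRA update delta_W = up @ down, computed on flattened 2-D factors and reshaped to the patched weight's shape. Below is a minimal, self-contained sketch of that idea; the function name and the way dyn_dim crops the rank are assumptions for illustration, not the extension's exact implementation.

import torch

def rebuild_conventional_sketch(up, down, shape, dyn_dim=None):
    # Flatten both factors to 2-D: up becomes [out_features, rank], down becomes [rank, in_features * k_h * k_w].
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    # Assumed dyn_dim handling: keep only the first dyn_dim ranks of each factor.
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    # The low-rank weight delta, reshaped to the shape of the layer it patches.
    return (up @ down).reshape(shape)

# Example: a rank-4 LoRA patching a Linear(640 -> 320) weight of shape [320, 640].
up = torch.randn(320, 4)
down = torch.randn(4, 640)
delta = rebuild_conventional_sketch(up, down, [320, 640])
assert delta.shape == (320, 640)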