network_lora.py

import torch

import lyco_helpers
import network
from modules import devices


class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        # Treat this as a classic LoRA only if both factor matrices are present.
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)

        return None


class NetworkModuleLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.up_model = self.create_module(weights.w, "lora_up.weight")
        self.down_model = self.create_module(weights.w, "lora_down.weight")
        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)

        self.dim = weights.w["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        # Builds a torch module matching the targeted layer's type and copies
        # the corresponding LoRA factor weights into it.
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        # note: `and` binds tighter than `or`, so the dyn_up/dyn_down keys match regardless of is_conv
        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, orig_weight):
        # Reconstructs the weight delta from the low-rank factors, either
        # conventionally (up @ down) or via CP decomposition when a mid factor exists.
        up = self.up_model.weight.to(orig_weight.device)
        down = self.down_model.weight.to(orig_weight.device)

        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition
            mid = self.mid_model.weight.to(orig_weight.device)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)

        return self.finalize_updown(updown, orig_weight, output_shape)

    def forward(self, x, y):
        # Adds the scaled low-rank path on top of the original layer's output y.
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)

        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
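
Aside (not part of the file above): in the non-CP branch of calc_updown, lyco_helpers.rebuild_conventional reconstructs the weight delta from the two low-rank factors. Below is a minimal sketch of that reconstruction, assuming the helper's flatten-multiply-reshape behavior; the function name rebuild_conventional_sketch and the tensor sizes are made up for illustration.

import torch

def rebuild_conventional_sketch(up, down, output_shape, dyn_dim=None):
    # Flatten both factors to 2D, optionally trim to the first dyn_dim ranks,
    # multiply, and restore the target weight shape.
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(output_shape)

# Example with made-up sizes: a rank-4 LoRA applied to a 320x768 Linear projection.
up = torch.randn(320, 4)    # lora_up.weight
down = torch.randn(4, 768)  # lora_down.weight
delta = rebuild_conventional_sketch(up, down, [320, 768])
assert delta.shape == (320, 768)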