hypernetwork.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import glob
  2. import os
  3. import sys
  4. import traceback
  5. import torch
  6. from modules import devices
  7. class HypernetworkModule(torch.nn.Module):
  8. def __init__(self, dim, state_dict):
  9. super().__init__()
  10. self.linear1 = torch.nn.Linear(dim, dim * 2)
  11. self.linear2 = torch.nn.Linear(dim * 2, dim)
  12. self.load_state_dict(state_dict, strict=True)
  13. self.to(devices.device)
  14. def forward(self, x):
  15. return x + (self.linear2(self.linear1(x)))
  16. class Hypernetwork:
  17. filename = None
  18. name = None
  19. def __init__(self, filename):
  20. self.filename = filename
  21. self.name = os.path.splitext(os.path.basename(filename))[0]
  22. self.layers = {}
  23. state_dict = torch.load(filename, map_location='cpu')
  24. for size, sd in state_dict.items():
  25. self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
  26. def load_hypernetworks(path):
  27. res = {}
  28. for filename in glob.iglob(path + '**/*.pt', recursive=True):
  29. try:
  30. hn = Hypernetwork(filename)
  31. res[hn.name] = hn
  32. except Exception:
  33. print(f"Error loading hypernetwork {filename}", file=sys.stderr)
  34. print(traceback.format_exc(), file=sys.stderr)
  35. return res
  36. def apply(self, x, context=None, mask=None, original=None):
  37. if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
  38. if context.shape[1] == 77 and CrossAttention.noise_cond:
  39. context = context + (torch.randn_like(context) * 0.1)
  40. h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
  41. k = self.to_k(h_k(context))
  42. v = self.to_v(h_v(context))
  43. else:
  44. k = self.to_k(context)
  45. v = self.to_v(context)