hypernetwork.py

import glob
import os
import sys
import traceback

import torch
from torch import einsum
from einops import rearrange, repeat

from ldm.util import default
from modules import devices, shared


class HypernetworkModule(torch.nn.Module):
    """Small residual MLP applied to the input of one attention projection."""

    def __init__(self, dim, state_dict):
        super().__init__()

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

        self.load_state_dict(state_dict, strict=True)
        self.to(devices.device)

    def forward(self, x):
        # residual connection: the module learns an offset on top of its input
        return x + (self.linear2(self.linear1(x)))


class Hypernetwork:
    filename = None
    name = None

    def __init__(self, filename):
        self.filename = filename
        self.name = os.path.splitext(os.path.basename(filename))[0]
        self.layers = {}

        state_dict = torch.load(filename, map_location='cpu')

        # the checkpoint maps a layer width to a pair of module state dicts:
        # one module for the key-projection input, one for the value-projection input
        for size, sd in state_dict.items():
            self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
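

# --- Checkpoint format sketch (illustrative, not part of the original file) ---
# Inferred from Hypernetwork.__init__ above: the .pt file maps a layer width to a
# pair of HypernetworkModule state dicts (key branch, value branch). The helper
# name and the widths below are assumptions; with all-zero weights each module is
# a no-op, since forward() returns x + 0.
def write_identity_hypernetwork(path, dims=(320, 640, 768, 1280)):
    def zero_module_sd(dim):
        return {
            'linear1.weight': torch.zeros(dim * 2, dim),
            'linear1.bias': torch.zeros(dim * 2),
            'linear2.weight': torch.zeros(dim, dim * 2),
            'linear2.bias': torch.zeros(dim),
        }

    torch.save({dim: (zero_module_sd(dim), zero_module_sd(dim)) for dim in dims}, path)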


def load_hypernetworks(path):
    res = {}

    for filename in glob.iglob(path + '**/*.pt', recursive=True):
        try:
            hn = Hypernetwork(filename)
            res[hn.name] = hn
        except Exception:
            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return res
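

# --- Loading sketch (illustrative, not part of the original file) ---
# load_hypernetworks() returns a {name: Hypernetwork} dict keyed by file stem;
# elsewhere in the webui, shared.selected_hypernetwork() is assumed to return one
# of these entries (or None). The function name and directory below are hypothetical.
def example_list_hypernetworks(path="models/hypernetworks/"):
    return sorted(load_hypernetworks(path))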


def attention_CrossAttention_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    # if a hypernetwork is selected and has modules for this context width,
    # run the key/value inputs through it before projecting
    hypernetwork = shared.selected_hypernetwork()
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
        k = self.to_k(hypernetwork_layers[0](context))
        v = self.to_v(hypernetwork_layers[1](context))
    else:
        k = self.to_k(context)
        v = self.to_v(context)

    # split the embedding dimension into heads and fold the heads into the batch
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)
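

# --- Patch sketch (illustrative, not part of the original file) ---
# attention_CrossAttention_forward is written as an unbound forward method; it is
# assumed to be installed by monkey-patching ldm's CrossAttention so the selected
# hypernetwork rewrites the k/v projection inputs. The function name below is
# hypothetical.
def apply_cross_attention_patch():
    import ldm.modules.attention

    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward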