sd_hijack_optimizations.py

import math

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    # compute attention two batch-head slices at a time, so the full score
    # tensor is never materialized at once
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
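

# Note (illustrative, not part of the upstream file): the sliced loop above is
# numerically equivalent to unsliced attention; it only bounds peak memory.
# A minimal standalone check on toy tensors, assuming the same einsum convention:
#
#     q = torch.randn(8, 16, 40); k = torch.randn(8, 16, 40); v = torch.randn(8, 16, 40)
#     scale = 40 ** -0.5
#     full = einsum('b i j, b j d -> b i d',
#                   (einsum('b i d, b j d -> b i j', q, k) * scale).softmax(dim=-1), v)
#
# Running the sliced loop over pairs of the first dimension and comparing with
# torch.allclose(full, sliced) should succeed within floating-point tolerance.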


# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    # if a hypernetwork is loaded, its layers transform the context before the
    # key/value projections
    hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
        k_in = self.to_k(hypernetwork_layers[0](context))
        v_in = self.to_v(hypernetwork_layers[1](context))
    else:
        k_in = self.to_k(context)
        v_in = self.to_v(context)

    k_in *= self.scale

    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    # estimate free VRAM: what the driver reports as free plus memory torch has
    # reserved but is not actively using
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    # split the attention computation into as many power-of-two steps as needed to fit
    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

        s2 = s1.softmax(dim=-1, dtype=q.dtype)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
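

# Worked example of the heuristic above (illustrative numbers, not from the file):
# for a 512x512 image the largest self-attention layer sees a 64x64 = 4096-token
# latent, so with batch 1 and 8 heads q has shape (8, 4096, d). In fp16
# (element_size == 2) the score tensor needs 8 * 4096 * 4096 * 2 bytes = 0.25 GB,
# giving mem_required = 0.75 GB. If only 0.5 GB is free,
# steps = 2 ** ceil(log2(0.75 / 0.5)) = 2 and slice_size = 4096 // 2 = 2048,
# i.e. the softmax is computed in two halves along the query dimension.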


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)    # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
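

# How these replacements get wired in happens outside this file; a minimal
# monkey-patching sketch, assuming the CompVis ldm package layout (an assumption;
# verify the module paths against your checkout):
#
#     import ldm.modules.attention
#     import ldm.modules.diffusionmodules.model
#
#     # replace the UNet cross-attention forward with the memory-splitting version
#     ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
#     # replace the VAE attention block forward with the sliced version
#     ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward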