sd_hijack_optimizations.py

import math
import torch
from torch import einsum

from ldm.util import default
from einops import rearrange


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
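    # Process the flattened (batch * heads) dimension two rows at a time, so the
    # intermediate attention matrix s1 only ever covers two heads at once, keeping
    # peak memory low at the cost of more kernel launches.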
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)
    k_in = self.to_k(context) * self.scale
    v_in = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

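    # Estimate available memory as free CUDA memory plus memory PyTorch has reserved
    # but is not actively using, then compare it against a rough size estimate of the
    # full attention matrix (query rows x key rows per head), with extra headroom for
    # temporaries: 3x for half precision, 2.5x otherwise.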
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

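    # Split the query rows into `steps` equally sized slices so each partial attention
    # matrix fits in the estimated free memory; if the row count is not divisible by
    # `steps`, fall back to a single full-size slice.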
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

        s2 = s1.softmax(dim=-1, dtype=q.dtype)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)

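
# Swish/SiLU computed with an in-place multiply (x *= sigmoid(x)) to avoid allocating
# a separate output tensor.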
def nonlinearity_hijack(x):
    # swish
    t = torch.sigmoid(x)
    x *= t
    del t

    return x


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

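    # Same free-memory heuristic as split_cross_attention_forward: free CUDA memory
    # plus memory PyTorch has reserved but is not actively using, compared against a
    # rough estimate of the (b, hw, hw) attention matrix with 2.5x headroom.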
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
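
    # Attention over the h*w spatial positions, one query slice at a time:
    # w = softmax(q @ k / sqrt(c)) for the slice, then the corresponding output
    # columns h_[:, :, i:end] = v @ w^T are written into the preallocated buffer.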
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3