sd_hijack_optimizations.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. import math
  2. import sys
  3. import traceback
  4. import psutil
  5. import torch
  6. from torch import einsum
  7. from ldm.util import default
  8. from einops import rearrange
  9. from modules import shared, errors, devices
  10. from modules.hypernetworks import hypernetwork
  11. from .sub_quadratic_attention import efficient_dot_product_attention
  12. if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
  13. try:
  14. import xformers.ops
  15. shared.xformers_available = True
  16. except Exception:
  17. print("Cannot import xformers", file=sys.stderr)
  18. print(traceback.format_exc(), file=sys.stderr)
  19. def get_available_vram():
  20. if shared.device.type == 'cuda':
  21. stats = torch.cuda.memory_stats(shared.device)
  22. mem_active = stats['active_bytes.all.current']
  23. mem_reserved = stats['reserved_bytes.all.current']
  24. mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
  25. mem_free_torch = mem_reserved - mem_active
  26. mem_free_total = mem_free_cuda + mem_free_torch
  27. return mem_free_total
  28. else:
  29. return psutil.virtual_memory().available
  30. # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
  31. def split_cross_attention_forward_v1(self, x, context=None, mask=None):
  32. h = self.heads
  33. q_in = self.to_q(x)
  34. context = default(context, x)
  35. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  36. k_in = self.to_k(context_k)
  37. v_in = self.to_v(context_v)
  38. del context, context_k, context_v, x
  39. q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
  40. del q_in, k_in, v_in
  41. dtype = q.dtype
  42. if shared.opts.upcast_attn:
  43. q, k, v = q.float(), k.float(), v.float()
  44. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  45. r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  46. for i in range(0, q.shape[0], 2):
  47. end = i + 2
  48. s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
  49. s1 *= self.scale
  50. s2 = s1.softmax(dim=-1)
  51. del s1
  52. r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
  53. del s2
  54. del q, k, v
  55. r1 = r1.to(dtype)
  56. r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
  57. del r1
  58. return self.to_out(r2)
  59. # taken from https://github.com/Doggettx/stable-diffusion and modified
  60. def split_cross_attention_forward(self, x, context=None, mask=None):
  61. h = self.heads
  62. q_in = self.to_q(x)
  63. context = default(context, x)
  64. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  65. k_in = self.to_k(context_k)
  66. v_in = self.to_v(context_v)
  67. dtype = q_in.dtype
  68. if shared.opts.upcast_attn:
  69. q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
  70. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  71. k_in = k_in * self.scale
  72. del context, x
  73. q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
  74. del q_in, k_in, v_in
  75. r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  76. mem_free_total = get_available_vram()
  77. gb = 1024 ** 3
  78. tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
  79. modifier = 3 if q.element_size() == 2 else 2.5
  80. mem_required = tensor_size * modifier
  81. steps = 1
  82. if mem_required > mem_free_total:
  83. steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
  84. # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
  85. # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
  86. if steps > 64:
  87. max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
  88. raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
  89. f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
  90. slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
  91. for i in range(0, q.shape[1], slice_size):
  92. end = i + slice_size
  93. s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
  94. s2 = s1.softmax(dim=-1, dtype=q.dtype)
  95. del s1
  96. r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
  97. del s2
  98. del q, k, v
  99. r1 = r1.to(dtype)
  100. r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
  101. del r1
  102. return self.to_out(r2)
  103. # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
  104. mem_total_gb = psutil.virtual_memory().total // (1 << 30)
  105. def einsum_op_compvis(q, k, v):
  106. s = einsum('b i d, b j d -> b i j', q, k)
  107. s = s.softmax(dim=-1, dtype=s.dtype)
  108. return einsum('b i j, b j d -> b i d', s, v)
  109. def einsum_op_slice_0(q, k, v, slice_size):
  110. r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  111. for i in range(0, q.shape[0], slice_size):
  112. end = i + slice_size
  113. r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
  114. return r
  115. def einsum_op_slice_1(q, k, v, slice_size):
  116. r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  117. for i in range(0, q.shape[1], slice_size):
  118. end = i + slice_size
  119. r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
  120. return r
  121. def einsum_op_mps_v1(q, k, v):
  122. if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
  123. return einsum_op_compvis(q, k, v)
  124. else:
  125. slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
  126. if slice_size % 4096 == 0:
  127. slice_size -= 1
  128. return einsum_op_slice_1(q, k, v, slice_size)
  129. def einsum_op_mps_v2(q, k, v):
  130. if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
  131. return einsum_op_compvis(q, k, v)
  132. else:
  133. return einsum_op_slice_0(q, k, v, 1)
  134. def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
  135. size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
  136. if size_mb <= max_tensor_mb:
  137. return einsum_op_compvis(q, k, v)
  138. div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
  139. if div <= q.shape[0]:
  140. return einsum_op_slice_0(q, k, v, q.shape[0] // div)
  141. return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
  142. def einsum_op_cuda(q, k, v):
  143. stats = torch.cuda.memory_stats(q.device)
  144. mem_active = stats['active_bytes.all.current']
  145. mem_reserved = stats['reserved_bytes.all.current']
  146. mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
  147. mem_free_torch = mem_reserved - mem_active
  148. mem_free_total = mem_free_cuda + mem_free_torch
  149. # Divide factor of safety as there's copying and fragmentation
  150. return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
  151. def einsum_op(q, k, v):
  152. if q.device.type == 'cuda':
  153. return einsum_op_cuda(q, k, v)
  154. if q.device.type == 'mps':
  155. if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
  156. return einsum_op_mps_v1(q, k, v)
  157. return einsum_op_mps_v2(q, k, v)
  158. # Smaller slices are faster due to L2/L3/SLC caches.
  159. # Tested on i7 with 8MB L3 cache.
  160. return einsum_op_tensor_mem(q, k, v, 32)
  161. def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
  162. h = self.heads
  163. q = self.to_q(x)
  164. context = default(context, x)
  165. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  166. k = self.to_k(context_k)
  167. v = self.to_v(context_v)
  168. del context, context_k, context_v, x
  169. dtype = q.dtype
  170. if shared.opts.upcast_attn:
  171. q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
  172. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  173. k = k * self.scale
  174. q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
  175. r = einsum_op(q, k, v)
  176. r = r.to(dtype)
  177. return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
  178. # -- End of code from https://github.com/invoke-ai/InvokeAI --
  179. # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
  180. # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
  181. def sub_quad_attention_forward(self, x, context=None, mask=None):
  182. assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
  183. h = self.heads
  184. q = self.to_q(x)
  185. context = default(context, x)
  186. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  187. k = self.to_k(context_k)
  188. v = self.to_v(context_v)
  189. del context, context_k, context_v, x
  190. q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  191. k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  192. v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  193. if q.device.type == 'mps':
  194. q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
  195. dtype = q.dtype
  196. if shared.opts.upcast_attn:
  197. q, k = q.float(), k.float()
  198. x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
  199. x = x.to(dtype)
  200. x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
  201. out_proj, dropout = self.to_out
  202. x = out_proj(x)
  203. x = dropout(x)
  204. return x
  205. def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
  206. bytes_per_token = torch.finfo(q.dtype).bits//8
  207. batch_x_heads, q_tokens, _ = q.shape
  208. _, k_tokens, _ = k.shape
  209. qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
  210. if chunk_threshold is None:
  211. chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
  212. elif chunk_threshold == 0:
  213. chunk_threshold_bytes = None
  214. else:
  215. chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
  216. if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
  217. kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
  218. elif kv_chunk_size_min == 0:
  219. kv_chunk_size_min = None
  220. if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
  221. # the big matmul fits into our memory limit; do everything in 1 chunk,
  222. # i.e. send it down the unchunked fast-path
  223. kv_chunk_size = k_tokens
  224. with devices.without_autocast(disable=q.dtype == v.dtype):
  225. return efficient_dot_product_attention(
  226. q,
  227. k,
  228. v,
  229. query_chunk_size=q_chunk_size,
  230. kv_chunk_size=kv_chunk_size,
  231. kv_chunk_size_min = kv_chunk_size_min,
  232. use_checkpoint=use_checkpoint,
  233. )
  234. def get_xformers_flash_attention_op(q, k, v):
  235. if not shared.cmd_opts.xformers_flash_attention:
  236. return None
  237. try:
  238. flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
  239. fw, bw = flash_attention_op
  240. if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
  241. return flash_attention_op
  242. except Exception as e:
  243. errors.display_once(e, "enabling flash attention")
  244. return None
  245. def xformers_attention_forward(self, x, context=None, mask=None):
  246. h = self.heads
  247. q_in = self.to_q(x)
  248. context = default(context, x)
  249. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  250. k_in = self.to_k(context_k)
  251. v_in = self.to_v(context_v)
  252. q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
  253. del q_in, k_in, v_in
  254. dtype = q.dtype
  255. if shared.opts.upcast_attn:
  256. q, k, v = q.float(), k.float(), v.float()
  257. out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
  258. out = out.to(dtype)
  259. out = rearrange(out, 'b n h d -> b n (h d)', h=h)
  260. return self.to_out(out)
  261. # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
  262. # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
  263. def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
  264. batch_size, sequence_length, inner_dim = x.shape
  265. if mask is not None:
  266. mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
  267. mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
  268. h = self.heads
  269. q_in = self.to_q(x)
  270. context = default(context, x)
  271. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  272. k_in = self.to_k(context_k)
  273. v_in = self.to_v(context_v)
  274. head_dim = inner_dim // h
  275. q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  276. k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  277. v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  278. del q_in, k_in, v_in
  279. dtype = q.dtype
  280. if shared.opts.upcast_attn:
  281. q, k, v = q.float(), k.float(), v.float()
  282. # the output of sdp = (batch, num_heads, seq_len, head_dim)
  283. hidden_states = torch.nn.functional.scaled_dot_product_attention(
  284. q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
  285. )
  286. hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
  287. hidden_states = hidden_states.to(dtype)
  288. # linear proj
  289. hidden_states = self.to_out[0](hidden_states)
  290. # dropout
  291. hidden_states = self.to_out[1](hidden_states)
  292. return hidden_states
  293. def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
  294. with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
  295. return scaled_dot_product_attention_forward(self, x, context, mask)
  296. def cross_attention_attnblock_forward(self, x):
  297. h_ = x
  298. h_ = self.norm(h_)
  299. q1 = self.q(h_)
  300. k1 = self.k(h_)
  301. v = self.v(h_)
  302. # compute attention
  303. b, c, h, w = q1.shape
  304. q2 = q1.reshape(b, c, h*w)
  305. del q1
  306. q = q2.permute(0, 2, 1) # b,hw,c
  307. del q2
  308. k = k1.reshape(b, c, h*w) # b,c,hw
  309. del k1
  310. h_ = torch.zeros_like(k, device=q.device)
  311. mem_free_total = get_available_vram()
  312. tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
  313. mem_required = tensor_size * 2.5
  314. steps = 1
  315. if mem_required > mem_free_total:
  316. steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
  317. slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
  318. for i in range(0, q.shape[1], slice_size):
  319. end = i + slice_size
  320. w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
  321. w2 = w1 * (int(c)**(-0.5))
  322. del w1
  323. w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
  324. del w2
  325. # attend to values
  326. v1 = v.reshape(b, c, h*w)
  327. w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
  328. del w3
  329. h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
  330. del v1, w4
  331. h2 = h_.reshape(b, c, h, w)
  332. del h_
  333. h3 = self.proj_out(h2)
  334. del h2
  335. h3 += x
  336. return h3
  337. def xformers_attnblock_forward(self, x):
  338. try:
  339. h_ = x
  340. h_ = self.norm(h_)
  341. q = self.q(h_)
  342. k = self.k(h_)
  343. v = self.v(h_)
  344. b, c, h, w = q.shape
  345. q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
  346. dtype = q.dtype
  347. if shared.opts.upcast_attn:
  348. q, k = q.float(), k.float()
  349. q = q.contiguous()
  350. k = k.contiguous()
  351. v = v.contiguous()
  352. out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
  353. out = out.to(dtype)
  354. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  355. out = self.proj_out(out)
  356. return x + out
  357. except NotImplementedError:
  358. return cross_attention_attnblock_forward(self, x)
  359. def sdp_attnblock_forward(self, x):
  360. h_ = x
  361. h_ = self.norm(h_)
  362. q = self.q(h_)
  363. k = self.k(h_)
  364. v = self.v(h_)
  365. b, c, h, w = q.shape
  366. q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
  367. dtype = q.dtype
  368. if shared.opts.upcast_attn:
  369. q, k = q.float(), k.float()
  370. q = q.contiguous()
  371. k = k.contiguous()
  372. v = v.contiguous()
  373. out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
  374. out = out.to(dtype)
  375. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  376. out = self.proj_out(out)
  377. return x + out
  378. def sdp_no_mem_attnblock_forward(self, x):
  379. with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
  380. return sdp_attnblock_forward(self, x)
  381. def sub_quad_attnblock_forward(self, x):
  382. h_ = x
  383. h_ = self.norm(h_)
  384. q = self.q(h_)
  385. k = self.k(h_)
  386. v = self.v(h_)
  387. b, c, h, w = q.shape
  388. q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
  389. q = q.contiguous()
  390. k = k.contiguous()
  391. v = v.contiguous()
  392. out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
  393. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  394. out = self.proj_out(out)
  395. return x + out