# sd_hijack_optimizations.py

from __future__ import annotations
import math

import psutil
import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model


diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

class SdOptimization:
    name: str = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward

class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward

class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward

class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward

class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1

class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI

class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward

def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])

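# Illustrative sketch only (this helper is not part of the webui and is never
# called here): one plausible way to consume the SdOptimization objects above
# is to keep the ones whose is_available() returns True and pick the highest
# priority. The webui's real selection logic lives elsewhere; this only
# demonstrates how name/title, is_available and priority are meant to interact.
def _example_pick_optimization(optimizations: list[SdOptimization]) -> SdOptimization | None:
    candidates = [o for o in optimizations if o.is_available()]
    if not candidates:
        return None
    return max(candidates, key=lambda o: o.priority)
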
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)

def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available

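# Worked example with hypothetical numbers: if torch.cuda.mem_get_info reports
# 2.0 GiB free on the device while the caching allocator has 3.0 GiB reserved
# and 2.5 GiB actively in use, then mem_free_torch = 0.5 GiB of allocator slack
# and get_available_vram() returns roughly 2.5 GiB.
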
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)

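# Note on the v1 path above: it always iterates over the fused (batch * heads)
# dimension two rows at a time, regardless of how much memory is free, so e.g.
# a q of shape (16, 4096, 40) is processed in 8 chunks of 2.
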
# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier

        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)

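# Illustrative sketch (hypothetical helper, not used by the webui): the step
# count above grows as the next power of two of mem_required / mem_free_total.
# For example, fp16 q/k with 16 batch-heads and 4096 tokens each give
# tensor_size = 16 * 4096 * 4096 * 2 bytes = 512 MiB, mem_required =
# 3 * 512 MiB = 1.5 GiB, and with 1 GiB free that means steps = 2 and
# slice_size = 4096 // 2 = 2048.
def _example_split_attention_steps(batch_x_heads, q_tokens, k_tokens, element_size, mem_free_total):
    tensor_size = batch_x_heads * q_tokens * k_tokens * element_size
    modifier = 3 if element_size == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
    return steps
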
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)


def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r


def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)


def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))


def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))


def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)

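# Worked example for einsum_op_tensor_mem with hypothetical shapes: a fp16 q of
# shape (16, 4096, 40) against 4096 key tokens gives
# size_mb = 16 * 4096 * 4096 * 2 // 2**20 = 512. On the CPU path above
# (max_tensor_mb=32) the divisor is 1 << int((512 - 1) / 32).bit_length() = 16,
# which is <= q.shape[0], so the work is split along dim 0 into slices of size
# 16 // 16 = 1.
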
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --

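# Design note on the InvokeAI path above: k is multiplied by self.scale before
# the heads are split, so einsum_op_compvis and the slicing helpers can skip
# the usual attention-scaling step inside the einsum itself; the Doggettx
# variant earlier in this file folds the scale into k the same way.
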
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x

def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )

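# Worked example for the chunking decision above (hypothetical numbers): with
# chunk_threshold left at None on a CUDA device and roughly 8 GiB reported
# free, chunk_threshold_bytes is about 5.6 GiB (70%). A fp16 attention with
# batch_x_heads=16 and q_tokens=k_tokens=4096 needs qk_matmul_size_bytes =
# 16 * 2 * 4096 * 4096 bytes = 512 MiB, which fits, so kv_chunk_size is set to
# k_tokens and the call takes the unchunked fast path. Passing chunk_threshold
# as a number treats it as a percentage of free memory; 0 disables the limit.
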
def get_xformers_flash_attention_op(q, k, v):
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None

def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)

# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states

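# Note on the SDP path above: the view/transpose pair puts q, k and v into
# (batch, heads, tokens, head_dim) layout, which is what
# torch.nn.functional.scaled_dot_product_attention expects. PyTorch then
# chooses a backend itself; the "no-mem" wrapper below narrows that choice by
# disabling the memory-efficient kernel via torch.backends.cuda.sdp_kernel.
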
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)

def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3

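# Note on the AttnBlock path above: it applies the same memory-aware slicing
# idea as split_cross_attention_forward, here to the spatial self-attention of
# ldm.modules.diffusionmodules.model.AttnBlock. The h*w x h*w attention matrix
# is built in column slices, with a fixed 2.5x modifier on the size estimate
# when deciding how many slicing steps are needed.
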
def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))

        out = out.to(dtype)

        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)

def sdp_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out

def sdp_no_mem_attnblock_forward(self, x):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)

def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out