sd_hijack_optimizations.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. import math
  2. import sys
  3. import traceback
  4. import psutil
  5. import torch
  6. from torch import einsum
  7. from ldm.util import default
  8. from einops import rearrange
  9. from modules import shared, errors, devices, sub_quadratic_attention
  10. from modules.hypernetworks import hypernetwork
  11. import ldm.modules.attention
  12. import ldm.modules.diffusionmodules.model
  13. diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
  14. class SdOptimization:
  15. name: str = None
  16. label: str | None = None
  17. cmd_opt: str | None = None
  18. priority: int = 0
  19. def title(self):
  20. if self.label is None:
  21. return self.name
  22. return f"{self.name} - {self.label}"
  23. def is_available(self):
  24. return True
  25. def apply(self):
  26. pass
  27. def undo(self):
  28. ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
  29. ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
  30. class SdOptimizationXformers(SdOptimization):
  31. name = "xformers"
  32. cmd_opt = "xformers"
  33. priority = 100
  34. def is_available(self):
  35. return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
  36. def apply(self):
  37. ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
  38. ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
  39. class SdOptimizationSdpNoMem(SdOptimization):
  40. name = "sdp-no-mem"
  41. label = "scaled dot product without memory efficient attention"
  42. cmd_opt = "opt_sdp_no_mem_attention"
  43. priority = 90
  44. def is_available(self):
  45. return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
  46. def apply(self):
  47. ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
  48. ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
  49. class SdOptimizationSdp(SdOptimizationSdpNoMem):
  50. name = "sdp"
  51. label = "scaled dot product"
  52. cmd_opt = "opt_sdp_attention"
  53. priority = 80
  54. def apply(self):
  55. ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
  56. ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
  57. class SdOptimizationSubQuad(SdOptimization):
  58. name = "sub-quadratic"
  59. cmd_opt = "opt_sub_quad_attention"
  60. priority = 10
  61. def apply(self):
  62. ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
  63. ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
  64. class SdOptimizationV1(SdOptimization):
  65. name = "V1"
  66. label = "original v1"
  67. cmd_opt = "opt_split_attention_v1"
  68. priority = 10
  69. def apply(self):
  70. ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
  71. class SdOptimizationInvokeAI(SdOptimization):
  72. name = "InvokeAI"
  73. cmd_opt = "opt_split_attention_invokeai"
  74. @property
  75. def priority(self):
  76. return 1000 if not torch.cuda.is_available() else 10
  77. def apply(self):
  78. ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
  79. class SdOptimizationDoggettx(SdOptimization):
  80. name = "Doggettx"
  81. cmd_opt = "opt_split_attention"
  82. priority = 20
  83. def apply(self):
  84. ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
  85. ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
  86. def list_optimizers(res):
  87. res.extend([
  88. SdOptimizationXformers(),
  89. SdOptimizationSdpNoMem(),
  90. SdOptimizationSdp(),
  91. SdOptimizationSubQuad(),
  92. SdOptimizationV1(),
  93. SdOptimizationInvokeAI(),
  94. SdOptimizationDoggettx(),
  95. ])
  96. if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
  97. try:
  98. import xformers.ops
  99. shared.xformers_available = True
  100. except Exception:
  101. print("Cannot import xformers", file=sys.stderr)
  102. print(traceback.format_exc(), file=sys.stderr)
  103. def get_available_vram():
  104. if shared.device.type == 'cuda':
  105. stats = torch.cuda.memory_stats(shared.device)
  106. mem_active = stats['active_bytes.all.current']
  107. mem_reserved = stats['reserved_bytes.all.current']
  108. mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
  109. mem_free_torch = mem_reserved - mem_active
  110. mem_free_total = mem_free_cuda + mem_free_torch
  111. return mem_free_total
  112. else:
  113. return psutil.virtual_memory().available
  114. # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
  115. def split_cross_attention_forward_v1(self, x, context=None, mask=None):
  116. h = self.heads
  117. q_in = self.to_q(x)
  118. context = default(context, x)
  119. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  120. k_in = self.to_k(context_k)
  121. v_in = self.to_v(context_v)
  122. del context, context_k, context_v, x
  123. q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
  124. del q_in, k_in, v_in
  125. dtype = q.dtype
  126. if shared.opts.upcast_attn:
  127. q, k, v = q.float(), k.float(), v.float()
  128. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  129. r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  130. for i in range(0, q.shape[0], 2):
  131. end = i + 2
  132. s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
  133. s1 *= self.scale
  134. s2 = s1.softmax(dim=-1)
  135. del s1
  136. r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
  137. del s2
  138. del q, k, v
  139. r1 = r1.to(dtype)
  140. r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
  141. del r1
  142. return self.to_out(r2)
  143. # taken from https://github.com/Doggettx/stable-diffusion and modified
  144. def split_cross_attention_forward(self, x, context=None, mask=None):
  145. h = self.heads
  146. q_in = self.to_q(x)
  147. context = default(context, x)
  148. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  149. k_in = self.to_k(context_k)
  150. v_in = self.to_v(context_v)
  151. dtype = q_in.dtype
  152. if shared.opts.upcast_attn:
  153. q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
  154. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  155. k_in = k_in * self.scale
  156. del context, x
  157. q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
  158. del q_in, k_in, v_in
  159. r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  160. mem_free_total = get_available_vram()
  161. gb = 1024 ** 3
  162. tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
  163. modifier = 3 if q.element_size() == 2 else 2.5
  164. mem_required = tensor_size * modifier
  165. steps = 1
  166. if mem_required > mem_free_total:
  167. steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
  168. # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
  169. # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
  170. if steps > 64:
  171. max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
  172. raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
  173. f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
  174. slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
  175. for i in range(0, q.shape[1], slice_size):
  176. end = i + slice_size
  177. s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
  178. s2 = s1.softmax(dim=-1, dtype=q.dtype)
  179. del s1
  180. r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
  181. del s2
  182. del q, k, v
  183. r1 = r1.to(dtype)
  184. r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
  185. del r1
  186. return self.to_out(r2)
  187. # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
  188. mem_total_gb = psutil.virtual_memory().total // (1 << 30)
  189. def einsum_op_compvis(q, k, v):
  190. s = einsum('b i d, b j d -> b i j', q, k)
  191. s = s.softmax(dim=-1, dtype=s.dtype)
  192. return einsum('b i j, b j d -> b i d', s, v)
  193. def einsum_op_slice_0(q, k, v, slice_size):
  194. r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  195. for i in range(0, q.shape[0], slice_size):
  196. end = i + slice_size
  197. r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
  198. return r
  199. def einsum_op_slice_1(q, k, v, slice_size):
  200. r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
  201. for i in range(0, q.shape[1], slice_size):
  202. end = i + slice_size
  203. r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
  204. return r
  205. def einsum_op_mps_v1(q, k, v):
  206. if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
  207. return einsum_op_compvis(q, k, v)
  208. else:
  209. slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
  210. if slice_size % 4096 == 0:
  211. slice_size -= 1
  212. return einsum_op_slice_1(q, k, v, slice_size)
  213. def einsum_op_mps_v2(q, k, v):
  214. if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
  215. return einsum_op_compvis(q, k, v)
  216. else:
  217. return einsum_op_slice_0(q, k, v, 1)
  218. def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
  219. size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
  220. if size_mb <= max_tensor_mb:
  221. return einsum_op_compvis(q, k, v)
  222. div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
  223. if div <= q.shape[0]:
  224. return einsum_op_slice_0(q, k, v, q.shape[0] // div)
  225. return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
  226. def einsum_op_cuda(q, k, v):
  227. stats = torch.cuda.memory_stats(q.device)
  228. mem_active = stats['active_bytes.all.current']
  229. mem_reserved = stats['reserved_bytes.all.current']
  230. mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
  231. mem_free_torch = mem_reserved - mem_active
  232. mem_free_total = mem_free_cuda + mem_free_torch
  233. # Divide factor of safety as there's copying and fragmentation
  234. return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
  235. def einsum_op(q, k, v):
  236. if q.device.type == 'cuda':
  237. return einsum_op_cuda(q, k, v)
  238. if q.device.type == 'mps':
  239. if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
  240. return einsum_op_mps_v1(q, k, v)
  241. return einsum_op_mps_v2(q, k, v)
  242. # Smaller slices are faster due to L2/L3/SLC caches.
  243. # Tested on i7 with 8MB L3 cache.
  244. return einsum_op_tensor_mem(q, k, v, 32)
  245. def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
  246. h = self.heads
  247. q = self.to_q(x)
  248. context = default(context, x)
  249. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  250. k = self.to_k(context_k)
  251. v = self.to_v(context_v)
  252. del context, context_k, context_v, x
  253. dtype = q.dtype
  254. if shared.opts.upcast_attn:
  255. q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
  256. with devices.without_autocast(disable=not shared.opts.upcast_attn):
  257. k = k * self.scale
  258. q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
  259. r = einsum_op(q, k, v)
  260. r = r.to(dtype)
  261. return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
  262. # -- End of code from https://github.com/invoke-ai/InvokeAI --
  263. # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
  264. # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
  265. def sub_quad_attention_forward(self, x, context=None, mask=None):
  266. assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
  267. h = self.heads
  268. q = self.to_q(x)
  269. context = default(context, x)
  270. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  271. k = self.to_k(context_k)
  272. v = self.to_v(context_v)
  273. del context, context_k, context_v, x
  274. q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  275. k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  276. v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
  277. if q.device.type == 'mps':
  278. q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
  279. dtype = q.dtype
  280. if shared.opts.upcast_attn:
  281. q, k = q.float(), k.float()
  282. x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
  283. x = x.to(dtype)
  284. x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
  285. out_proj, dropout = self.to_out
  286. x = out_proj(x)
  287. x = dropout(x)
  288. return x
  289. def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
  290. bytes_per_token = torch.finfo(q.dtype).bits//8
  291. batch_x_heads, q_tokens, _ = q.shape
  292. _, k_tokens, _ = k.shape
  293. qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
  294. if chunk_threshold is None:
  295. chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
  296. elif chunk_threshold == 0:
  297. chunk_threshold_bytes = None
  298. else:
  299. chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
  300. if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
  301. kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
  302. elif kv_chunk_size_min == 0:
  303. kv_chunk_size_min = None
  304. if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
  305. # the big matmul fits into our memory limit; do everything in 1 chunk,
  306. # i.e. send it down the unchunked fast-path
  307. kv_chunk_size = k_tokens
  308. with devices.without_autocast(disable=q.dtype == v.dtype):
  309. return sub_quadratic_attention.efficient_dot_product_attention(
  310. q,
  311. k,
  312. v,
  313. query_chunk_size=q_chunk_size,
  314. kv_chunk_size=kv_chunk_size,
  315. kv_chunk_size_min = kv_chunk_size_min,
  316. use_checkpoint=use_checkpoint,
  317. )
  318. def get_xformers_flash_attention_op(q, k, v):
  319. if not shared.cmd_opts.xformers_flash_attention:
  320. return None
  321. try:
  322. flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
  323. fw, bw = flash_attention_op
  324. if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
  325. return flash_attention_op
  326. except Exception as e:
  327. errors.display_once(e, "enabling flash attention")
  328. return None
  329. def xformers_attention_forward(self, x, context=None, mask=None):
  330. h = self.heads
  331. q_in = self.to_q(x)
  332. context = default(context, x)
  333. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  334. k_in = self.to_k(context_k)
  335. v_in = self.to_v(context_v)
  336. q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
  337. del q_in, k_in, v_in
  338. dtype = q.dtype
  339. if shared.opts.upcast_attn:
  340. q, k, v = q.float(), k.float(), v.float()
  341. out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
  342. out = out.to(dtype)
  343. out = rearrange(out, 'b n h d -> b n (h d)', h=h)
  344. return self.to_out(out)
  345. # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
  346. # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
  347. def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
  348. batch_size, sequence_length, inner_dim = x.shape
  349. if mask is not None:
  350. mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
  351. mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
  352. h = self.heads
  353. q_in = self.to_q(x)
  354. context = default(context, x)
  355. context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
  356. k_in = self.to_k(context_k)
  357. v_in = self.to_v(context_v)
  358. head_dim = inner_dim // h
  359. q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  360. k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  361. v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
  362. del q_in, k_in, v_in
  363. dtype = q.dtype
  364. if shared.opts.upcast_attn:
  365. q, k, v = q.float(), k.float(), v.float()
  366. # the output of sdp = (batch, num_heads, seq_len, head_dim)
  367. hidden_states = torch.nn.functional.scaled_dot_product_attention(
  368. q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
  369. )
  370. hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
  371. hidden_states = hidden_states.to(dtype)
  372. # linear proj
  373. hidden_states = self.to_out[0](hidden_states)
  374. # dropout
  375. hidden_states = self.to_out[1](hidden_states)
  376. return hidden_states
  377. def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
  378. with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
  379. return scaled_dot_product_attention_forward(self, x, context, mask)
  380. def cross_attention_attnblock_forward(self, x):
  381. h_ = x
  382. h_ = self.norm(h_)
  383. q1 = self.q(h_)
  384. k1 = self.k(h_)
  385. v = self.v(h_)
  386. # compute attention
  387. b, c, h, w = q1.shape
  388. q2 = q1.reshape(b, c, h*w)
  389. del q1
  390. q = q2.permute(0, 2, 1) # b,hw,c
  391. del q2
  392. k = k1.reshape(b, c, h*w) # b,c,hw
  393. del k1
  394. h_ = torch.zeros_like(k, device=q.device)
  395. mem_free_total = get_available_vram()
  396. tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
  397. mem_required = tensor_size * 2.5
  398. steps = 1
  399. if mem_required > mem_free_total:
  400. steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
  401. slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
  402. for i in range(0, q.shape[1], slice_size):
  403. end = i + slice_size
  404. w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
  405. w2 = w1 * (int(c)**(-0.5))
  406. del w1
  407. w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
  408. del w2
  409. # attend to values
  410. v1 = v.reshape(b, c, h*w)
  411. w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
  412. del w3
  413. h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
  414. del v1, w4
  415. h2 = h_.reshape(b, c, h, w)
  416. del h_
  417. h3 = self.proj_out(h2)
  418. del h2
  419. h3 += x
  420. return h3
  421. def xformers_attnblock_forward(self, x):
  422. try:
  423. h_ = x
  424. h_ = self.norm(h_)
  425. q = self.q(h_)
  426. k = self.k(h_)
  427. v = self.v(h_)
  428. b, c, h, w = q.shape
  429. q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
  430. dtype = q.dtype
  431. if shared.opts.upcast_attn:
  432. q, k = q.float(), k.float()
  433. q = q.contiguous()
  434. k = k.contiguous()
  435. v = v.contiguous()
  436. out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
  437. out = out.to(dtype)
  438. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  439. out = self.proj_out(out)
  440. return x + out
  441. except NotImplementedError:
  442. return cross_attention_attnblock_forward(self, x)
  443. def sdp_attnblock_forward(self, x):
  444. h_ = x
  445. h_ = self.norm(h_)
  446. q = self.q(h_)
  447. k = self.k(h_)
  448. v = self.v(h_)
  449. b, c, h, w = q.shape
  450. q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
  451. dtype = q.dtype
  452. if shared.opts.upcast_attn:
  453. q, k = q.float(), k.float()
  454. q = q.contiguous()
  455. k = k.contiguous()
  456. v = v.contiguous()
  457. out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
  458. out = out.to(dtype)
  459. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  460. out = self.proj_out(out)
  461. return x + out
  462. def sdp_no_mem_attnblock_forward(self, x):
  463. with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
  464. return sdp_attnblock_forward(self, x)
  465. def sub_quad_attnblock_forward(self, x):
  466. h_ = x
  467. h_ = self.norm(h_)
  468. q = self.q(h_)
  469. k = self.k(h_)
  470. v = self.v(h_)
  471. b, c, h, w = q.shape
  472. q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
  473. q = q.contiguous()
  474. k = k.contiguous()
  475. v = v.contiguous()
  476. out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
  477. out = rearrange(out, 'b (h w) c -> b c h w', h=h)
  478. out = self.proj_out(out)
  479. return x + out