# sd_hijack_optimizations.py

from __future__ import annotations
import math

import psutil
import platform

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

import sgm.modules.attention
import sgm.modules.diffusionmodules.model

diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
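

# Each SdOptimization subclass describes one attention optimization:
# is_available() reports whether it can be used on the current setup, apply()
# patches the CrossAttention/AttnBlock forward methods of both the ldm and sgm
# modules, and undo() (on the base class) restores the original functions.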
class SdOptimization:
    name: str | None = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward


class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"

    @property
    def priority(self):
        return 1000 if shared.device.type == 'mps' else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
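

# list_optimizers registers every optimization defined above; which one is
# actually applied is decided elsewhere in the web UI, based on availability,
# user settings and the priority values (higher is preferred).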
def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)
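

# Estimate of memory currently free on the active device: free CUDA memory plus
# memory the torch allocator has reserved but is not actively using, or the
# available system RAM on non-CUDA devices.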
def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available
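

# Original memory-saving attention: iterates over the (batch * heads) dimension
# two rows at a time, so only a small part of the attention matrix exists at once.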
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
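

# Splits the attention computation along the query-token dimension; the number
# of slices is chosen from how much VRAM is currently free, and a RuntimeError
# is raised if even 64 slices would not fit.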
# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier

        steps = 1
        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)


def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r


def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)


def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))


def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
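

# Picks an einsum strategy per device: CUDA sizes the slices from free VRAM,
# MPS uses heuristics based on total system memory, and other devices use small
# slices that fit in CPU caches.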
def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)


def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --
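

# Sub-quadratic attention avoids materializing the full attention matrix by
# chunking over the query and key/value tokens.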
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x
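

# Decides how aggressively to chunk: the threshold defaults to roughly 70% of
# free VRAM (a fixed budget on MPS), chunk_threshold=0 disables the limit, and
# any other value is treated as a percentage of free VRAM. If the full q @ k^T
# product fits under the threshold, everything runs as a single chunk.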
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        if q.device.type == 'mps':
            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
        else:
            chunk_threshold_bytes = int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
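

# Returns xformers' flash-attention op when shared.cmd_opts.xformers_flash_attention
# is set and the op supports the given tensors; otherwise None, letting xformers
# pick an op on its own.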
def get_xformers_flash_attention_op(q, k, v):
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None


def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    b, n, h, d = out.shape
    out = out.reshape(b, n, h * d)
    return self.to_out(out)


# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states


def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)
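

# Memory-split self-attention for AttnBlock layers: the spatial attention is
# computed with torch.bmm in slices sized to the VRAM that is currently free.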
def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
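

# AttnBlock counterparts of the optimizations above: each flattens the
# (b, c, h, w) feature map into tokens, runs the corresponding attention kernel,
# and reshapes back; the xformers variant falls back to the split implementation
# on NotImplementedError.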
def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)


def sdp_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out


def sdp_no_mem_attnblock_forward(self, x):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)


def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out