  1. """
  2. Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
  3. Warn: The patch works well only if the input image has a width and height that are multiples of 128
  4. Original author: @tfernd Github: https://github.com/tfernd/HyperTile
  5. """
  6. from __future__ import annotations
  7. from dataclasses import dataclass
  8. from typing import Callable
  9. from functools import wraps, cache
  10. import math
  11. import torch.nn as nn
  12. import random
  13. from einops import rearrange


@dataclass
class HypertileParams:
    depth: int = 0
    layer_name: str = ""
    tile_size: int = 0
    swap_size: int = 0
    aspect_ratio: float = 1.0
    forward: Callable | None = None
    enabled: bool = False


# TODO add SD-XL layers
DEPTH_LAYERS = {
    0: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.1.1.transformer_blocks.0.attn1",
        "input_blocks.2.1.transformer_blocks.0.attn1",
        "output_blocks.9.1.transformer_blocks.0.attn1",
        "output_blocks.10.1.transformer_blocks.0.attn1",
        "output_blocks.11.1.transformer_blocks.0.attn1",
        # SD 1.5 VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.6.1.transformer_blocks.0.attn1",
        "output_blocks.7.1.transformer_blocks.0.attn1",
        "output_blocks.8.1.transformer_blocks.0.attn1",
    ],
    2: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
    ],
    3: [
        # SD 1.5 U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
    ],
}


# XL layers, thanks to GitHub @gel-crabs for the help
DEPTH_LAYERS_XL = {
    0: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
        # SD 1.5 VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SD 1.5 U-Net (diffusers)
        #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.1.attn1",
        "input_blocks.5.1.transformer_blocks.1.attn1",
        "output_blocks.3.1.transformer_blocks.1.attn1",
        "output_blocks.4.1.transformer_blocks.1.attn1",
        "output_blocks.5.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.0.1.transformer_blocks.0.attn1",
        "output_blocks.1.1.transformer_blocks.0.attn1",
        "output_blocks.2.1.transformer_blocks.0.attn1",
        "input_blocks.7.1.transformer_blocks.1.attn1",
        "input_blocks.8.1.transformer_blocks.1.attn1",
        "output_blocks.0.1.transformer_blocks.1.attn1",
        "output_blocks.1.1.transformer_blocks.1.attn1",
        "output_blocks.2.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.2.attn1",
        "input_blocks.8.1.transformer_blocks.2.attn1",
        "output_blocks.0.1.transformer_blocks.2.attn1",
        "output_blocks.1.1.transformer_blocks.2.attn1",
        "output_blocks.2.1.transformer_blocks.2.attn1",
        "input_blocks.7.1.transformer_blocks.3.attn1",
        "input_blocks.8.1.transformer_blocks.3.attn1",
        "output_blocks.0.1.transformer_blocks.3.attn1",
        "output_blocks.1.1.transformer_blocks.3.attn1",
        "output_blocks.2.1.transformer_blocks.3.attn1",
        "input_blocks.7.1.transformer_blocks.4.attn1",
        "input_blocks.8.1.transformer_blocks.4.attn1",
        "output_blocks.0.1.transformer_blocks.4.attn1",
        "output_blocks.1.1.transformer_blocks.4.attn1",
        "output_blocks.2.1.transformer_blocks.4.attn1",
        "input_blocks.7.1.transformer_blocks.5.attn1",
        "input_blocks.8.1.transformer_blocks.5.attn1",
        "output_blocks.0.1.transformer_blocks.5.attn1",
        "output_blocks.1.1.transformer_blocks.5.attn1",
        "output_blocks.2.1.transformer_blocks.5.attn1",
        "input_blocks.7.1.transformer_blocks.6.attn1",
        "input_blocks.8.1.transformer_blocks.6.attn1",
        "output_blocks.0.1.transformer_blocks.6.attn1",
        "output_blocks.1.1.transformer_blocks.6.attn1",
        "output_blocks.2.1.transformer_blocks.6.attn1",
        "input_blocks.7.1.transformer_blocks.7.attn1",
        "input_blocks.8.1.transformer_blocks.7.attn1",
        "output_blocks.0.1.transformer_blocks.7.attn1",
        "output_blocks.1.1.transformer_blocks.7.attn1",
        "output_blocks.2.1.transformer_blocks.7.attn1",
        "input_blocks.7.1.transformer_blocks.8.attn1",
        "input_blocks.8.1.transformer_blocks.8.attn1",
        "output_blocks.0.1.transformer_blocks.8.attn1",
        "output_blocks.1.1.transformer_blocks.8.attn1",
        "output_blocks.2.1.transformer_blocks.8.attn1",
        "input_blocks.7.1.transformer_blocks.9.attn1",
        "input_blocks.8.1.transformer_blocks.9.attn1",
        "output_blocks.0.1.transformer_blocks.9.attn1",
        "output_blocks.1.1.transformer_blocks.9.attn1",
        "output_blocks.2.1.transformer_blocks.9.attn1",
    ],
    2: [
        # SD 1.5 U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
        "middle_block.1.transformer_blocks.1.attn1",
        "middle_block.1.transformer_blocks.2.attn1",
        "middle_block.1.transformer_blocks.3.attn1",
        "middle_block.1.transformer_blocks.4.attn1",
        "middle_block.1.transformer_blocks.5.attn1",
        "middle_block.1.transformer_blocks.6.attn1",
        "middle_block.1.transformer_blocks.7.attn1",
        "middle_block.1.transformer_blocks.8.attn1",
        "middle_block.1.transformer_blocks.9.attn1",
    ],
    3: [],  # TODO - separate layers for SD-XL
}


RNG_INSTANCE = random.Random()


@cache
def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
    """
    Returns divisors x of value such that x * min_value <= value,
    in big -> small order; the number of divisors returned is limited by max_options.
    """
    max_options = max(1, max_options)  # at least 1 option should be returned
    min_value = min(min_value, value)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]  # divisors in small -> big order
    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element # big -> small order
    return ns
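
# Worked example (illustrative, not part of the original module): get_divisors(32, 6, max_options=2)
# first collects the divisors of 32 that are >= 6 in ascending order ([8, 16, 32]) and then returns
# the co-divisors of the first two, [32 // 8, 32 // 16] == [4, 2], i.e. the largest valid tile counts
# in descending order.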


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """
    Returns a random divisor x of value such that x * min_value <= value;
    if max_options is 1, the behavior is deterministic.
    """
    ns = get_divisors(value, min_value, max_options=max_options)  # get cached divisors
    idx = RNG_INSTANCE.randint(0, len(ns) - 1)
    return ns[idx]
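
# Worked example (illustrative): random_divisor(96, 16) with the default max_options=1 always
# returns 96 // 16 == 6, since only the single largest candidate is considered; with max_options > 1
# the module-level RNG_INSTANCE (seeded via set_hypertile_seed) picks one of the top candidates at
# random, which is what varies the tiling between calls.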


def set_hypertile_seed(seed: int) -> None:
    RNG_INSTANCE.seed(seed)


@cache
def largest_tile_size_available(width: int, height: int) -> int:
    """
    Calculates the largest tile size available for a given width and height
    Tile size is always a power of 2
    """
    gcd = math.gcd(width, height)
    largest_tile_size_available = 1
    while gcd % (largest_tile_size_available * 2) == 0:
        largest_tile_size_available *= 2
    return largest_tile_size_available
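
# Worked example (illustrative): for a 1024x768 image, math.gcd(1024, 768) == 256, so the loop
# doubles the result up to 256, the largest power of 2 dividing the gcd; hypertile_hook_model below
# then clamps this with min(..., tile_size_max).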


def iterative_closest_divisors(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w = aspect_ratio
    We check all possible divisors of hw and return the closest to the aspect ratio
    """
    divisors = [i for i in range(2, hw + 1) if hw % i == 0]  # all divisors of hw
    pairs = [(i, hw // i) for i in divisors]  # all pairs of divisors of hw
    ratios = [w / h for h, w in pairs]  # all ratios of pairs of divisors of hw
    closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio))  # closest ratio to aspect_ratio
    closest_pair = pairs[ratios.index(closest_ratio)]  # closest pair of divisors to aspect_ratio
    return closest_pair
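
# Worked example (illustrative): iterative_closest_divisors(36, 1.0) enumerates the divisor pairs
# of 36 ((2, 18), (3, 12), (4, 9), (6, 6), ...) and returns (6, 6), the pair whose ratio is closest
# to the requested aspect ratio of 1.0.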


@cache
def find_hw_candidates(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w = aspect_ratio
    """
    h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
    # find h and w such that h*w = hw and h/w = aspect_ratio
    if h * w != hw:
        w_candidate = hw / h
        # check if w is an integer
        if not w_candidate.is_integer():
            h_candidate = hw / w
            # check if h is an integer
            if not h_candidate.is_integer():
                return iterative_closest_divisors(hw, aspect_ratio)
            else:
                h = int(h_candidate)
        else:
            w = int(w_candidate)

    return h, w
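
# Worked example (illustrative): find_hw_candidates(6144, 1.5) tries the square-root guess first:
# h = round(sqrt(6144 * 1.5)) == 96 and w = round(sqrt(6144 / 1.5)) == 64, and since 96 * 64 == 6144
# it returns (96, 64) without falling back to iterative_closest_divisors.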


def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
    @wraps(params.forward)
    def wrapper(*args, **kwargs):
        if not params.enabled:
            return params.forward(*args, **kwargs)

        latent_tile_size = max(128, params.tile_size) // 8
        x = args[0]

        # VAE
        if x.ndim == 4:
            b, c, h, w = x.shape

            nh = random_divisor(h, latent_tile_size, params.swap_size)
            nw = random_divisor(w, latent_tile_size, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)  # split into nh * nw tiles

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)

        # U-Net
        else:
            hw: int = x.size(1)
            h, w = find_hw_candidates(hw, params.aspect_ratio)
            assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

            factor = 2 ** params.depth if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)

        return out

    return wrapper
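
# Rough sketch of the U-Net branch above with made-up sizes (illustrative): for a token tensor of
# shape (b, 4096, c) coming from a 64x64 latent, with nh == nw == 2 the first rearrange regroups it
# into (b * 4, 1024, c), so self-attention runs over four 32x32 tiles instead of one 64x64 map; the
# two rearranges after params.forward undo the split and restore the original (b, 4096, c) layout.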


def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
    hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
    if hypertile_layers is None:
        if not enable:
            return

        hypertile_layers = {}
        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS

        for depth in range(4):
            for layer_name, module in model.named_modules():
                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
                    params = HypertileParams()
                    module.__webui_hypertile_params = params
                    params.forward = module.forward
                    params.depth = depth
                    params.layer_name = layer_name
                    module.forward = self_attn_forward(params)

                    hypertile_layers[layer_name] = 1

        model.__webui_hypertile_layers = hypertile_layers

    aspect_ratio = width / height
    tile_size = min(largest_tile_size_available(width, height), tile_size_max)

    for layer_name, module in model.named_modules():
        if layer_name in hypertile_layers:
            params = module.__webui_hypertile_params

            params.tile_size = tile_size
            params.swap_size = swap_size
            params.aspect_ratio = aspect_ratio
            params.enabled = enable and params.depth <= max_depth
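
# Hypothetical usage sketch (the surrounding variables, such as `unet_model` and `seed`, are
# assumptions, not part of this module):
#
#   set_hypertile_seed(seed)
#   hypertile_hook_model(
#       unet_model,                # U-Net (or VAE) to patch
#       width=1024, height=768,
#       enable=True, tile_size_max=128, swap_size=2, max_depth=0, is_sdxl=False,
#   )
#
# The first call with enable=True wraps the matching attention layers in place; later calls with
# enable=False keep the wrappers but make every patched forward fall through to the original
# attention implementation.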