- """
- Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
- Warn: The patch works well only if the input image has a width and height that are multiples of 128
- Original author: @tfernd GitHub: https://github.com/tfernd/HyperTile
- """
from __future__ import annotations

import math
import random
from dataclasses import dataclass
from functools import wraps, cache
from typing import Callable

import torch.nn as nn
from einops import rearrange


@dataclass
class HypertileParams:
    depth: int = 0
    layer_name: str = ""
    tile_size: int = 0
    swap_size: int = 0
    aspect_ratio: float = 1.0
    forward: Callable | None = None
    enabled: bool = False


# TODO add SD-XL layers
DEPTH_LAYERS = {
    0: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.1.1.transformer_blocks.0.attn1",
        "input_blocks.2.1.transformer_blocks.0.attn1",
        "output_blocks.9.1.transformer_blocks.0.attn1",
        "output_blocks.10.1.transformer_blocks.0.attn1",
        "output_blocks.11.1.transformer_blocks.0.attn1",
        # SD 1.5 VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.6.1.transformer_blocks.0.attn1",
        "output_blocks.7.1.transformer_blocks.0.attn1",
        "output_blocks.8.1.transformer_blocks.0.attn1",
    ],
    2: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
    ],
    3: [
        # SD 1.5 U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
    ],
}

# SD-XL layers, thanks to GitHub user @gel-crabs for the help
DEPTH_LAYERS_XL = {
    0: [
        # SD-XL U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD-XL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
        # SD-XL VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SD-XL U-Net (diffusers)
        #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD-XL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.1.attn1",
        "input_blocks.5.1.transformer_blocks.1.attn1",
        "output_blocks.3.1.transformer_blocks.1.attn1",
        "output_blocks.4.1.transformer_blocks.1.attn1",
        "output_blocks.5.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.0.1.transformer_blocks.0.attn1",
        "output_blocks.1.1.transformer_blocks.0.attn1",
        "output_blocks.2.1.transformer_blocks.0.attn1",
        "input_blocks.7.1.transformer_blocks.1.attn1",
        "input_blocks.8.1.transformer_blocks.1.attn1",
        "output_blocks.0.1.transformer_blocks.1.attn1",
        "output_blocks.1.1.transformer_blocks.1.attn1",
        "output_blocks.2.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.2.attn1",
        "input_blocks.8.1.transformer_blocks.2.attn1",
        "output_blocks.0.1.transformer_blocks.2.attn1",
        "output_blocks.1.1.transformer_blocks.2.attn1",
        "output_blocks.2.1.transformer_blocks.2.attn1",
        "input_blocks.7.1.transformer_blocks.3.attn1",
        "input_blocks.8.1.transformer_blocks.3.attn1",
        "output_blocks.0.1.transformer_blocks.3.attn1",
        "output_blocks.1.1.transformer_blocks.3.attn1",
        "output_blocks.2.1.transformer_blocks.3.attn1",
        "input_blocks.7.1.transformer_blocks.4.attn1",
        "input_blocks.8.1.transformer_blocks.4.attn1",
        "output_blocks.0.1.transformer_blocks.4.attn1",
        "output_blocks.1.1.transformer_blocks.4.attn1",
        "output_blocks.2.1.transformer_blocks.4.attn1",
        "input_blocks.7.1.transformer_blocks.5.attn1",
        "input_blocks.8.1.transformer_blocks.5.attn1",
        "output_blocks.0.1.transformer_blocks.5.attn1",
        "output_blocks.1.1.transformer_blocks.5.attn1",
        "output_blocks.2.1.transformer_blocks.5.attn1",
        "input_blocks.7.1.transformer_blocks.6.attn1",
        "input_blocks.8.1.transformer_blocks.6.attn1",
        "output_blocks.0.1.transformer_blocks.6.attn1",
        "output_blocks.1.1.transformer_blocks.6.attn1",
        "output_blocks.2.1.transformer_blocks.6.attn1",
        "input_blocks.7.1.transformer_blocks.7.attn1",
        "input_blocks.8.1.transformer_blocks.7.attn1",
        "output_blocks.0.1.transformer_blocks.7.attn1",
        "output_blocks.1.1.transformer_blocks.7.attn1",
        "output_blocks.2.1.transformer_blocks.7.attn1",
        "input_blocks.7.1.transformer_blocks.8.attn1",
        "input_blocks.8.1.transformer_blocks.8.attn1",
        "output_blocks.0.1.transformer_blocks.8.attn1",
        "output_blocks.1.1.transformer_blocks.8.attn1",
        "output_blocks.2.1.transformer_blocks.8.attn1",
        "input_blocks.7.1.transformer_blocks.9.attn1",
        "input_blocks.8.1.transformer_blocks.9.attn1",
        "output_blocks.0.1.transformer_blocks.9.attn1",
        "output_blocks.1.1.transformer_blocks.9.attn1",
        "output_blocks.2.1.transformer_blocks.9.attn1",
    ],
    2: [
        # SD-XL U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD-XL U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
        "middle_block.1.transformer_blocks.1.attn1",
        "middle_block.1.transformer_blocks.2.attn1",
        "middle_block.1.transformer_blocks.3.attn1",
        "middle_block.1.transformer_blocks.4.attn1",
        "middle_block.1.transformer_blocks.5.attn1",
        "middle_block.1.transformer_blocks.6.attn1",
        "middle_block.1.transformer_blocks.7.attn1",
        "middle_block.1.transformer_blocks.8.attn1",
        "middle_block.1.transformer_blocks.9.attn1",
    ],
    3: [],  # TODO - separate layers for SD-XL
}


RNG_INSTANCE = random.Random()


@cache
def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
    """
    Returns divisors n of value such that n * min_value <= value,
    in big -> small order; the number of divisors returned is capped at max_options.
    """
    max_options = max(1, max_options)  # at least 1 option should be returned
    min_value = min(min_value, value)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]  # divisors >= min_value, small -> big order
    ns = [value // i for i in divisors[:max_options]]  # complementary divisors, big -> small order; has at least 1 element
    return ns
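
# Illustrative example (not part of the original module): with value=64 and
# min_value=8, the divisors of 64 that are >= 8 are [8, 16, 32, 64], so the
# complementary divisors are [8, 4, 2, 1], capped at max_options:
#   get_divisors(64, 8, max_options=3)  # -> [8, 4, 2]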


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """
    Returns a random divisor n of value such that n * min_value <= value.
    If max_options is 1, only the largest such divisor is considered,
    so the behavior is deterministic.
    """
    ns = get_divisors(value, min_value, max_options=max_options)  # get cached divisors
    idx = RNG_INSTANCE.randint(0, len(ns) - 1)
    return ns[idx]
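
# Illustrative example (not part of the original module): seeding the module RNG
# (see set_hypertile_seed below) makes the chosen tile split reproducible:
#   set_hypertile_seed(42)
#   random_divisor(64, 8, max_options=3)  # one of [8, 4, 2], chosen reproducibly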


def set_hypertile_seed(seed: int) -> None:
    RNG_INSTANCE.seed(seed)


@cache
def largest_tile_size_available(width: int, height: int) -> int:
    """
    Calculates the largest tile size available for a given width and height:
    the largest power of 2 that divides both dimensions.
    """
    gcd = math.gcd(width, height)
    tile_size = 1
    while gcd % (tile_size * 2) == 0:
        tile_size *= 2
    return tile_size
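
# Illustrative example (not part of the original module): for a 1152x896 image,
# gcd(1152, 896) = 128, which is already a power of 2, so:
#   largest_tile_size_available(1152, 896)  # -> 128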


def iterative_closest_divisors(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w is as close as possible to aspect_ratio.
    We check all possible divisor pairs of hw and return the closest to the aspect ratio.
    """
    divisors = [i for i in range(2, hw + 1) if hw % i == 0]  # all divisors of hw
    pairs = [(i, hw // i) for i in divisors]  # all (h, w) pairs of divisors of hw
    ratios = [h / w for h, w in pairs]  # h/w of each pair, matching the h/w = aspect_ratio convention used by find_hw_candidates
    closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio))  # closest ratio to aspect_ratio
    closest_pair = pairs[ratios.index(closest_ratio)]  # pair whose ratio is closest to aspect_ratio
    return closest_pair
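
# Illustrative example (not part of the original module): hw=12 has divisor
# pairs (2, 6), (3, 4), (4, 3), (6, 2), (12, 1); for aspect_ratio=1.5 the
# closest h/w ratio is 4/3, so:
#   iterative_closest_divisors(12, 1.5)  # -> (4, 3)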


@cache
def find_hw_candidates(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w is as close as possible to aspect_ratio.
    """
    h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
    # find h and w such that h*w = hw and h/w = aspect_ratio
    if h * w != hw:
        w_candidate = hw / h
        # check if w is an integer
        if not w_candidate.is_integer():
            h_candidate = hw / w
            # check if h is an integer
            if not h_candidate.is_integer():
                return iterative_closest_divisors(hw, aspect_ratio)
            else:
                h = int(h_candidate)
        else:
            w = int(w_candidate)

    return h, w
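
# Illustrative example (not part of the original module): for a 512x512 image
# the first U-Net self-attention layer operates on a 64x64 latent, i.e.
# hw = 4096 tokens, and with aspect_ratio=1.0 the rounded square root is exact:
#   find_hw_candidates(4096, 1.0)  # -> (64, 64)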


def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
    @wraps(params.forward)
    def wrapper(*args, **kwargs):
        if not params.enabled:
            return params.forward(*args, **kwargs)

        latent_tile_size = max(128, params.tile_size) // 8  # pixel tile size -> latent tile size (the VAE downscales by 8)
        x = args[0]

        # VAE
        if x.ndim == 4:
            b, c, h, w = x.shape

            nh = random_divisor(h, latent_tile_size, params.swap_size)
            nw = random_divisor(w, latent_tile_size, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)  # split into nh * nw tiles

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)

        # U-Net
        else:
            hw: int = x.size(1)
            h, w = find_hw_candidates(hw, params.aspect_ratio)
            assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

            factor = 2 ** params.depth if scale_depth else 1  # deeper layers see smaller latents, so scale the tile floor accordingly
            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)

        return out

    return wrapper
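
# Illustrative shape walk-through (not part of the original module): for a U-Net
# self-attention input of shape (b, 4096, c) with h = w = 64, latent_tile_size = 16
# and depth 0, nh and nw are divisors of 64 no larger than 64 // 16 = 4; picking
# nh = nw = 4 turns the input into (b * 16, 256, c), i.e. 16 tiles of 16x16 tokens,
# before the original forward runs on each tile.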


def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
    hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
    if hypertile_layers is None:
        if not enable:
            return

        # first call: patch the matching attention layers once and remember them
        hypertile_layers = {}
        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS

        for depth in range(4):
            for layer_name, module in model.named_modules():
                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
                    params = HypertileParams()
                    module.__webui_hypertile_params = params
                    params.forward = module.forward
                    params.depth = depth
                    params.layer_name = layer_name
                    module.forward = self_attn_forward(params)

                    hypertile_layers[layer_name] = 1

        model.__webui_hypertile_layers = hypertile_layers

    # on every call: refresh the per-generation parameters of the patched layers
    aspect_ratio = width / height
    tile_size = min(largest_tile_size_available(width, height), tile_size_max)

    for layer_name, module in model.named_modules():
        if layer_name in hypertile_layers:
            params = module.__webui_hypertile_params

            params.tile_size = tile_size
            params.swap_size = swap_size
            params.aspect_ratio = aspect_ratio
            params.enabled = enable and params.depth <= max_depth
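
# Minimal usage sketch (not part of the original module; assumes a diffusers
# SD-1.5 UNet2DConditionModel is available). The hook patches the attention
# layers in place on the first call with enable=True; later calls only update
# the per-generation parameters:
#
#   from diffusers import UNet2DConditionModel
#
#   unet = UNet2DConditionModel.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", subfolder="unet"
#   )
#   set_hypertile_seed(42)  # make the random tile splits reproducible
#   hypertile_hook_model(unet, 1024, 1024, enable=True, tile_size_max=128, swap_size=2, max_depth=0)
#   # ... run the denoising loop as usual; hooked attn1 layers now work on tiles ...
#   hypertile_hook_model(unet, 1024, 1024, enable=False)  # disable without unpatching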