sd_hijack_unet.py

import torch
from packaging import version
from einops import repeat
import math

from modules import devices
from modules.sd_hijack_utils import CondFunc


class TorchHijackForUnet:
    """
    This behaves like torch, but its cat resizes tensors to matching spatial dimensions when they disagree;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64.
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()
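
# Illustrative sketch (not part of the original module): th.cat behaves like
# torch.cat, except that with exactly two tensors whose spatial sizes differ,
# the first is nearest-neighbor resized to match the second:
#
#     a = torch.zeros(1, 4, 32, 32)
#     b = torch.zeros(1, 4, 40, 40)
#     th.cat((a, b), dim=1).shape  # torch.Size([1, 8, 40, 40])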


# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    """Always make sure inputs to the UNet are in the correct dtype."""
    if isinstance(cond, dict):
        # cast every conditioning tensor (including tensors inside lists) to the UNet dtype
        for y in cond.keys():
            if isinstance(cond[y], list):
                cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
            else:
                cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]

    with devices.autocast():
        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
        if devices.unet_needs_upcast:
            return result.float()
        else:
            return result
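
# Illustrative sketch of the effect (assumes devices.dtype_unet == torch.float16
# and devices.unet_needs_upcast == True; `orig_apply_model`, `model`, `x` and `t`
# are hypothetical placeholders, not names from this codebase):
#
#     cond = {'c_crossattn': [torch.randn(1, 77, 768)]}       # float32 conditioning
#     out = apply_model(orig_apply_model, model, x, t, cond)
#     out.dtype  # torch.float32 -- the UNet ran in float16, result upcast back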


# Monkey patch to create the timestep embedding tensor directly on the device,
# avoiding a blocking host-to-device copy.
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
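
# Sketch: the leading `_` receives the original function from CondFunc and is
# ignored, since this patch fully replaces the original. A call then looks like:
#
#     emb = timestep_embedding(None, torch.tensor([0.0, 10.0]), dim=8)
#     emb.shape  # torch.Size([2, 8]); rows are [cos(t*f), sin(t*f)] per frequency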


# Monkey patch for SpatialTransformer that removes unnecessary contiguous() calls.
# Prevents a lot of unnecessary aten::copy_ calls.
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]

    b, c, h, w = x.shape
    x_in = x
    x = self.norm(x)

    if not self.use_linear:
        x = self.proj_in(x)

    x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)

    if self.use_linear:
        x = self.proj_in(x)

    for i, block in enumerate(self.transformer_blocks):
        x = block(x, context=context[i])

    if self.use_linear:
        x = self.proj_out(x)

    x = x.view(b, h, w, c).permute(0, 3, 1, 2)

    if not self.use_linear:
        x = self.proj_out(x)

    return x + x_in
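
# Design note (from our reading of the upstream ldm implementation): the
# original SpatialTransformer.forward reshapes with
#     rearrange(x, 'b c h w -> b (h w) c').contiguous()
# The permute/reshape pair above expresses the same layout change while
# avoiding the unconditional .contiguous() copies, letting PyTorch copy only
# when it actually must.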


class GELUHijack(torch.nn.GELU, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)

    def forward(self, x):
        if devices.unet_needs_upcast:
            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
        else:
            return torch.nn.GELU.forward(self, x)
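
# Sketch: when upcasting is needed, the activation runs in float32 and the
# result is cast back, so a UNet-dtype input round-trips safely either way:
#
#     gelu = GELUHijack()
#     gelu(torch.randn(4, dtype=devices.dtype_unet)).dtype  # devices.dtype_unet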


ddpm_edit_hijack = None
def hijack_ddpm_edit():
    global ddpm_edit_hijack
    if not ddpm_edit_hijack:
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
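
# Usage sketch: hijack_ddpm_edit() is safe to call more than once; the
# ddpm_edit_hijack guard makes repeated calls no-ops. (When the rest of the
# codebase invokes it is not shown here.)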


unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast

CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)

if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
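
# Sketch: first_stage_cond gates the VAE hijacks so they fire only while a
# float16 UNet is being upcast; first_stage_sub then feeds the VAE its own
# dtype, so e.g. decode_first_stage(z) effectively becomes
# decode_first_stage(z.to(devices.dtype_vae)).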

# Always make sure inputs to the UNet are in the correct dtype
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)


def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
    if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
        dtype = torch.float32
    else:
        dtype = devices.dtype_unet
    return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)

# Always make sure the timestep calculation is in the correct dtype
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
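
# CondFunc convention used throughout this module (see modules.sd_hijack_utils):
# CondFunc(func_path, sub_func, cond_func) replaces the callable at func_path so
# that sub_func(orig_func, *args, **kwargs) runs whenever
# cond_func(orig_func, *args, **kwargs) is truthy; with no cond_func, sub_func
# always runs. A minimal sketch of a registration in that style (the target
# path here is purely illustrative, not one this module actually patches):
#
#     CondFunc('ldm.modules.attention.CrossAttention.forward',
#              lambda orig_func, self, x, **kwargs: orig_func(self, x.float(), **kwargs),
#              unet_needs_upcast)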