|
@@ -1,5 +1,7 @@
|
|
|
import torch
|
|
|
from packaging import version
|
|
|
+from einops import repeat
|
|
|
+import math
|
|
|
|
|
|
from modules import devices
|
|
|
from modules.sd_hijack_utils import CondFunc
|
|
@@ -36,7 +38,7 @@ th = TorchHijackForUnet()
|
|
|
|
|
|
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
|
|
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
|
|
-
|
|
|
+ """Always make sure inputs to unet are in correct dtype."""
|
|
|
if isinstance(cond, dict):
|
|
|
for y in cond.keys():
|
|
|
if isinstance(cond[y], list):
|
|
@@ -45,7 +47,59 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
|
|
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
|
|
|
|
|
with devices.autocast():
|
|
|
- return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
|
|
+ result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
|
|
+ if devices.unet_needs_upcast:
|
|
|
+ return result.float()
|
|
|
+ else:
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+# Monkey patch to create timestep embed tensor on device, avoiding a block.
|
|
|
+def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
|
|
+ """
|
|
|
+ Create sinusoidal timestep embeddings.
|
|
|
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
|
|
|
+ These may be fractional.
|
|
|
+ :param dim: the dimension of the output.
|
|
|
+ :param max_period: controls the minimum frequency of the embeddings.
|
|
|
+ :return: an [N x dim] Tensor of positional embeddings.
|
|
|
+ """
|
|
|
+ if not repeat_only:
|
|
|
+ half = dim // 2
|
|
|
+ freqs = torch.exp(
|
|
|
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
|
|
+ )
|
|
|
+ args = timesteps[:, None].float() * freqs[None]
|
|
|
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
|
|
+ if dim % 2:
|
|
|
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
|
|
+ else:
|
|
|
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
|
|
|
+ return embedding
|
|
|
+
|
|
|
+
|
|
|
+# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
|
|
+# Prevents a lot of unnecessary aten::copy_ calls
|
|
|
+def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
|
|
+ # note: if no context is given, cross-attention defaults to self-attention
|
|
|
+ if not isinstance(context, list):
|
|
|
+ context = [context]
|
|
|
+ b, c, h, w = x.shape
|
|
|
+ x_in = x
|
|
|
+ x = self.norm(x)
|
|
|
+ if not self.use_linear:
|
|
|
+ x = self.proj_in(x)
|
|
|
+ x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
|
|
+ if self.use_linear:
|
|
|
+ x = self.proj_in(x)
|
|
|
+ for i, block in enumerate(self.transformer_blocks):
|
|
|
+ x = block(x, context=context[i])
|
|
|
+ if self.use_linear:
|
|
|
+ x = self.proj_out(x)
|
|
|
+ x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
|
|
+ if not self.use_linear:
|
|
|
+ x = self.proj_out(x)
|
|
|
+ return x + x_in
|
|
|
|
|
|
|
|
|
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
|
@@ -64,12 +118,15 @@ def hijack_ddpm_edit():
|
|
|
if not ddpm_edit_hijack:
|
|
|
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
|
|
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
|
|
- ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
|
|
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
|
|
|
|
|
|
|
|
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
|
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
|
|
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
|
|
+CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
|
|
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
|
|
+
|
|
|
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
|
|
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
|
|
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
|
@@ -81,5 +138,17 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s
|
|
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
|
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
|
|
|
|
|
-CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
|
|
|
-CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
|
|
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
|
|
+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
|
|
+
|
|
|
+
|
|
|
+def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
|
|
+ if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
|
|
+ dtype = torch.float32
|
|
|
+ else:
|
|
|
+ dtype = devices.dtype_unet
|
|
|
+ return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
|
|
+
|
|
|
+
|
|
|
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
|
|
+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|