Преглед на файлове

Add transformer forward patch

drhead преди 1 година
родител
ревизия
cc9ca67664
променени са 1 файла, в които са добавени 26 реда и са изтрити 1 реда
  1. 26 1
      modules/sd_hijack_unet.py

+ 26 - 1
modules/sd_hijack_unet.py

@@ -74,6 +74,30 @@ def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
     return embedding
 
 
+# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
+# Prevents a lot of unnecessary aten::copy_ calls
+def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
+    # note: if no context is given, cross-attention defaults to self-attention
+    if not isinstance(context, list):
+        context = [context]
+    b, c, h, w = x.shape
+    x_in = x
+    x = self.norm(x)
+    if not self.use_linear:
+        x = self.proj_in(x)
+    x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+    if self.use_linear:
+        x = self.proj_in(x)
+    for i, block in enumerate(self.transformer_blocks):
+        x = block(x, context=context[i])
+    if self.use_linear:
+        x = self.proj_out(x)
+    x = x.view(b, h, w, c).permute(0, 3, 1, 2)
+    if not self.use_linear:
+        x = self.proj_out(x)
+    return x + x_in
+
+
 class GELUHijack(torch.nn.GELU, torch.nn.Module):
     def __init__(self, *args, **kwargs):
         torch.nn.GELU.__init__(self, *args, **kwargs)
@@ -95,7 +119,8 @@ def hijack_ddpm_edit():
 
 unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding, lambda *args, **kwargs: True)
+CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward, lambda *args, **kwargs: True)
 CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
 if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
     CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)