Browse Source

alternate implementation for unet forward replacement that does not depend on hijack being applied

AUTOMATIC1111 1 year ago
parent
commit
ac02216e54
2 changed files with 13 additions and 8 deletions
  1. 5 2
      modules/sd_hijack.py
  2. 8 6
      modules/sd_unet.py

+ 5 - 2
modules/sd_hijack.py

@@ -38,8 +38,11 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
 
-ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
-sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
+ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
+
+sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
 
 
 def list_optimizers():

+ 8 - 6
modules/sd_unet.py

@@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices
 unet_options = []
 current_unet_option = None
 current_unet = None
-original_forward = None
-
+original_forward = None  # not used, only left temporarily for compatibility
 
 def list_unets():
     new_unets = script_callbacks.list_unets_callback()
@@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
         pass
 
 
-def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-    if current_unet is not None:
-        return current_unet.forward(x, timesteps, context, *args, **kwargs)
+def create_unet_forward(original_forward):
+    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+        if current_unet is not None:
+            return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+        return original_forward(self, x, timesteps, context, *args, **kwargs)
 
-    return original_forward(self, x, timesteps, context, *args, **kwargs)
+    return UNetModel_forward