|
@@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices
|
|
|
unet_options = []
|
|
|
current_unet_option = None
|
|
|
current_unet = None
|
|
|
-original_forward = None
|
|
|
-
|
|
|
+original_forward = None # not used, only left temporarily for compatibility
|
|
|
|
|
|
def list_unets():
|
|
|
new_unets = script_callbacks.list_unets_callback()
|
|
@@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
|
|
|
pass
|
|
|
|
|
|
|
|
|
-def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
|
|
- if current_unet is not None:
|
|
|
- return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
|
|
+def create_unet_forward(original_forward):
|
|
|
+ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
|
|
+ if current_unet is not None:
|
|
|
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
|
|
+
|
|
|
+ return original_forward(self, x, timesteps, context, *args, **kwargs)
|
|
|
|
|
|
- return original_forward(self, x, timesteps, context, *args, **kwargs)
|
|
|
+ return UNetModel_forward
|
|
|
|