Browse Source

fix for unet hijack breaking the train tab

AUTOMATIC 2 years ago
parent
commit
15e89ef0f6
1 changed files with 5 additions and 2 deletions
  1. 5 2
      modules/sd_hijack_unet.py

+ 5 - 2
modules/sd_hijack_unet.py

@@ -36,8 +36,11 @@ th = TorchHijackForUnet()
 
 # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
 def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
-    for y in cond.keys():
-        cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
+    if isinstance(cond, dict):
+        for y in cond.keys():
+            cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
     with devices.autocast():
         return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()