فهرست منبع

In-place operations can break gradient calculation

brkirch 2 سال پیش
والد
کامیت
df3b31eb55
1فایلهای تغییر یافته به همراه2 افزوده شده و 2 حذف شده
  1. 2 2
      modules/sd_hijack_clip.py

+ 2 - 2
modules/sd_hijack_clip.py

@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
         original_mean = z.mean()
-        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         new_mean = z.mean()
-        z *= original_mean / new_mean
+        z = z * (original_mean / new_mean)
 
         return z