浏览代码

In-place operations can break gradient calculation

brkirch 2 年之前
父节点
当前提交
df3b31eb55
共有 1 个文件被更改,包括 2 次插入2 次删除
  1. 2 2
      modules/sd_hijack_clip.py

+ 2 - 2
modules/sd_hijack_clip.py

@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
         batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
         original_mean = z.mean()
         original_mean = z.mean()
-        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         new_mean = z.mean()
         new_mean = z.mean()
-        z *= original_mean / new_mean
+        z = z * (original_mean / new_mean)
 
 
         return z
         return z