浏览代码

mps, xpu compatibility

w-e-w 1 年之前
父节点
当前提交
41f66849c7
共有 1 个文件被更改,包括 4 次插入5 次删除
  1. 4 5
      extensions-builtin/soft-inpainting/scripts/soft_inpainting.py

+ 4 - 5
extensions-builtin/soft-inpainting/scripts/soft_inpainting.py

@@ -3,6 +3,7 @@ import gradio as gr
 import math
 from modules.ui_components import InputAccordion
 import modules.scripts as scripts
+from modules.torch_utils import float64
 
 
 class SoftInpaintingSettings:
@@ -79,13 +80,11 @@ def latent_blend(settings, a, b, t):
 
     # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
     # 64-bit operations are used here to allow large exponents.
-    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
+    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(float64(image_interp)).add_(0.00001)
 
     # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
-    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
-        settings.inpaint_detail_preservation) * one_minus_t3
-    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
-        settings.inpaint_detail_preservation) * t3
+    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(float64(a)).pow_(settings.inpaint_detail_preservation) * one_minus_t3
+    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(float64(b)).pow_(settings.inpaint_detail_preservation) * t3
     desired_magnitude = a_magnitude
     desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
     del a_magnitude, b_magnitude, t3, one_minus_t3