Browse Source

load_spandrel_model: make `half` `prefer_half`

As discussed with the Spandrel folks, it's good to heed Spandrel's
"supports half precision" flag to avoid e.g. black blotches and what-not.
Aarni Koskela 1 năm trước cách đây
mục cha
commit
2cacbc124c
2 tập tin đã thay đổi với 15 bổ sung7 xóa
  1. 14 6
      modules/modelloader.py
  2. 1 1
      modules/realesrgan_model.py

+ 14 - 6
modules/modelloader.py

@@ -139,23 +139,31 @@ def load_upscalers():
 
 
 def load_spandrel_model(
-    path: str,
+    path: str | os.PathLike,
     *,
     device: str | torch.device | None,
-    half: bool = False,
+    prefer_half: bool = False,
     dtype: str | torch.dtype | None = None,
     expected_architecture: str | None = None,
 ) -> spandrel.ModelDescriptor:
     import spandrel
-    model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
+    model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
     if expected_architecture and model_descriptor.architecture != expected_architecture:
         logger.warning(
             f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
         )
-    if half:
-        model_descriptor.model.half()
+    half = False
+    if prefer_half:
+        if model_descriptor.supports_half:
+            model_descriptor.model.half()
+            half = True
+        else:
+            logger.info("Model %s does not support half precision, ignoring --half", path)
     if dtype:
         model_descriptor.model.to(dtype=dtype)
     model_descriptor.model.eval()
-    logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
+    logger.debug(
+        "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
+        model_descriptor, path, device, half, dtype,
+    )
     return model_descriptor

+ 1 - 1
modules/realesrgan_model.py

@@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler):
         model_descriptor = modelloader.load_spandrel_model(
             info.local_data_path,
             device=self.device,
-            half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
             expected_architecture="ESRGAN",  # "RealESRGAN" isn't a specific thing for Spandrel
         )
         return upscale_with_model(