Sfoglia il codice sorgente

Verify architecture for loaded Spandrel models

Aarni Koskela 1 anno fa
parent
commit
4ad0c0c0a8

+ 1 - 1
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
             filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        return modelloader.load_spandrel_model(filename, device=device)
+        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
 
 
 def on_ui_settings():

+ 1 - 0
extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
             filename,
             device=self._get_device(),
             dtype=devices.dtype,
+            expected_architecture="SwinIR",
         )
         if getattr(opts, 'SWIN_torch_compile', False):
             try:

+ 1 - 0
modules/codeformer_model.py

@@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
             return modelloader.load_spandrel_model(
                 model_path,
                 device=devices.device_codeformer,
+                expected_architecture='CodeFormer',
             ).model
         raise ValueError("No codeformer model found")
 

+ 1 - 0
modules/esrgan_model.py

@@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
         return modelloader.load_spandrel_model(
             filename,
             device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+            expected_architecture='ESRGAN',
         )
 
 

+ 1 - 0
modules/gfpgan_model.py

@@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
                 net = modelloader.load_spandrel_model(
                     model_path,
                     device=self.get_device(),
+                    expected_architecture='GFPGAN',
                 ).model
                 net.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
                 return net

+ 1 - 0
modules/hat_model.py

@@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
         return modelloader.load_spandrel_model(
             path,
             device=devices.device_esrgan,  # TODO: should probably be device_hat
+            expected_architecture='HAT',
         )

+ 12 - 1
modules/modelloader.py

@@ -6,6 +6,8 @@ import shutil
 import importlib
 from urllib.parse import urlparse
 
+import torch
+
 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
 from modules.paths import script_path, models_path
@@ -183,9 +185,18 @@ def load_upscalers():
     )
 
 
-def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
+def load_spandrel_model(
+    path: str,
+    *,
+    device: str | torch.device | None,
+    half: bool = False,
+    dtype: str | None = None,
+    expected_architecture: str | None = None,
+):
     import spandrel
     model = spandrel.ModelLoader(device=device).load_from_file(path)
+    if expected_architecture and model.architecture != expected_architecture:
+        raise TypeError(f"Model {path} is not a {expected_architecture} model")
     if half:
         model = model.model.half()
     if dtype:

+ 4 - 3
modules/realesrgan_model.py

@@ -1,9 +1,9 @@
 import os
 
-from modules.upscaler_utils import upscale_with_model
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
 from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
 
 
 class UpscalerRealESRGAN(Upscaler):
@@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
             info.local_data_path,
             device=self.device,
             half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+            expected_architecture="RealESRGAN",
         )
         return upscale_with_model(
             mod,