|
@@ -6,6 +6,8 @@ import shutil
|
|
|
import importlib
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
+import torch
|
|
|
+
|
|
|
from modules import shared
|
|
|
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
|
|
from modules.paths import script_path, models_path
|
|
@@ -183,9 +185,18 @@ def load_upscalers():
|
|
|
)
|
|
|
|
|
|
|
|
|
-def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
|
|
|
+def load_spandrel_model(
|
|
|
+ path: str,
|
|
|
+ *,
|
|
|
+ device: str | torch.device | None,
|
|
|
+ half: bool = False,
|
|
|
+ dtype: str | None = None,
|
|
|
+ expected_architecture: str | None = None,
|
|
|
+):
|
|
|
import spandrel
|
|
|
model = spandrel.ModelLoader(device=device).load_from_file(path)
|
|
|
+ if expected_architecture and model.architecture != expected_architecture:
|
|
|
+ raise TypeError(f"Model {path} is not a {expected_architecture} model")
|
|
|
if half:
|
|
|
model = model.model.half()
|
|
|
if dtype:
|