Explorar o código

Merge pull request #16144 from akx/bump-spandrel

Bump spandrel to 0.3.4
AUTOMATIC1111 hai 1 ano
pai
achega
a6c384b9f7
Modificáronse 3 ficheiros con 32 adicións e 7 borrados
  1. 1 3
      modules/gfpgan_model.py
  2. 29 3
      modules/modelloader.py
  3. 2 1
      requirements_versions.txt

+ 1 - 3
modules/gfpgan_model.py

@@ -36,13 +36,11 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
             ext_filter=['.pth'],
         ):
             if 'GFPGAN' in os.path.basename(model_path):
-                model = modelloader.load_spandrel_model(
+                return modelloader.load_spandrel_model(
                     model_path,
                     device=self.get_device(),
                     expected_architecture='GFPGAN',
                 ).model
-                model.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
-                return model
         raise ValueError("No GFPGAN model found")
 
     def restore(self, np_image):

+ 29 - 3
modules/modelloader.py

@@ -139,6 +139,27 @@ def load_upscalers():
         key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
     )
 
+# None: not loaded, False: failed to load, True: loaded
+_spandrel_extra_init_state = None
+
+
+def _init_spandrel_extra_archs() -> None:
+    """
+    Try to initialize `spandrel_extra_archs` (exactly once).
+    """
+    global _spandrel_extra_init_state
+    if _spandrel_extra_init_state is not None:
+        return
+
+    try:
+        import spandrel
+        import spandrel_extra_arches
+        spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY)
+        _spandrel_extra_init_state = True
+    except Exception:
+        logger.warning("Failed to load spandrel_extra_arches", exc_info=True)
+        _spandrel_extra_init_state = False
+
 
 def load_spandrel_model(
     path: str | os.PathLike,
@@ -148,11 +169,16 @@ def load_spandrel_model(
     dtype: str | torch.dtype | None = None,
     expected_architecture: str | None = None,
 ) -> spandrel.ModelDescriptor:
+    global _spandrel_extra_init_state
+
     import spandrel
+    _init_spandrel_extra_archs()
+
     model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
-    if expected_architecture and model_descriptor.architecture != expected_architecture:
+    arch = model_descriptor.architecture
+    if expected_architecture and arch.name != expected_architecture:
         logger.warning(
-            f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
+            f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})",
         )
     half = False
     if prefer_half:
@@ -166,6 +192,6 @@ def load_spandrel_model(
     model_descriptor.model.eval()
     logger.debug(
         "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
-        model_descriptor, path, device, half, dtype,
+        arch, path, device, half, dtype,
     )
     return model_descriptor

+ 2 - 1
requirements_versions.txt

@@ -24,7 +24,8 @@ pytorch_lightning==1.9.4
 resize-right==0.0.2
 safetensors==0.4.2
 scikit-image==0.21.0
-spandrel==0.1.6
+spandrel==0.3.4
+spandrel-extra-arches==0.1.1
 tomesd==0.1.3
 torch
 torchdiffeq==0.2.3