Эх сурвалжийг харах

Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN

Aarni Koskela 1 жил өмнө
parent
commit
b621a63cf6

+ 8 - 0
.github/workflows/run_tests.yaml

@@ -20,6 +20,12 @@ jobs:
           cache-dependency-path: |
           cache-dependency-path: |
             **/requirements*txt
             **/requirements*txt
             launch.py
             launch.py
+      - name: Cache models
+        id: cache-models
+        uses: actions/cache@v3
+        with:
+          path: models
+          key: "2023-12-30"
       - name: Install test dependencies
       - name: Install test dependencies
         run: pip install wait-for-it -r requirements-test.txt
         run: pip install wait-for-it -r requirements-test.txt
         env:
         env:
@@ -33,6 +39,8 @@ jobs:
           TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
           TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
           WEBUI_LAUNCH_LIVE_OUTPUT: "1"
           WEBUI_LAUNCH_LIVE_OUTPUT: "1"
           PYTHONUNBUFFERED: "1"
           PYTHONUNBUFFERED: "1"
+      - name: Print installed packages
+        run: pip freeze
       - name: Start test server
       - name: Start test server
         run: >
         run: >
           python -m coverage run
           python -m coverage run

+ 1 - 0
.gitignore

@@ -37,3 +37,4 @@ notification.mp3
 /node_modules
 /node_modules
 /package-lock.json
 /package-lock.json
 /.coverage*
 /.coverage*
+/test/test_outputs

+ 40 - 118
modules/codeformer_model.py

@@ -1,140 +1,62 @@
-import os
+from __future__ import annotations
+
+import logging
 
 
-import cv2
 import torch
 import torch
 
 
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
+from modules import (
+    devices,
+    errors,
+    face_restoration,
+    face_restoration_utils,
+    modelloader,
+    shared,
+)
+
+logger = logging.getLogger(__name__)
 
 
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
 model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
 model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_download_name = 'codeformer-v0.1.0.pth'
 
 
-codeformer = None
+# used by e.g. postprocessing_codeformer.py
+codeformer: face_restoration.FaceRestoration | None = None
 
 
 
 
-class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
     def name(self):
     def name(self):
         return "CodeFormer"
         return "CodeFormer"
 
 
-    def __init__(self, dirname):
-        self.net = None
-        self.face_helper = None
-        self.cmd_dir = dirname
-
-    def create_models(self):
-        from facexlib.detection import retinaface
-        from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-
-        if self.net is not None and self.face_helper is not None:
-            self.net.to(devices.device_codeformer)
-            return self.net, self.face_helper
-        model_paths = modelloader.load_models(
-            model_path,
-            model_url,
-            self.cmd_dir,
-            download_name='codeformer-v0.1.0.pth',
+    def load_net(self) -> torch.Module:
+        for model_path in modelloader.load_models(
+            model_path=self.model_path,
+            model_url=model_url,
+            command_path=self.model_path,
+            download_name=model_download_name,
             ext_filter=['.pth'],
             ext_filter=['.pth'],
-        )
-
-        if len(model_paths) != 0:
-            ckpt_path = model_paths[0]
-        else:
-            print("Unable to load codeformer model.")
-            return None, None
-        net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer)
-
-        if hasattr(retinaface, 'device'):
-            retinaface.device = devices.device_codeformer
-
-        face_helper = FaceRestoreHelper(
-            upscale_factor=1,
-            face_size=512,
-            crop_ratio=(1, 1),
-            det_model='retinaface_resnet50',
-            save_ext='png',
-            use_parse=True,
-            device=devices.device_codeformer,
-        )
-
-        self.net = net
-        self.face_helper = face_helper
-
-    def send_model_to(self, device):
-        self.net.to(device)
-        self.face_helper.face_det.to(device)
-        self.face_helper.face_parse.to(device)
-
-    def restore(self, np_image, w=None):
-        from torchvision.transforms.functional import normalize
-        from basicsr.utils import img2tensor, tensor2img
-        np_image = np_image[:, :, ::-1]
-
-        original_resolution = np_image.shape[0:2]
-
-        self.create_models()
-        if self.net is None or self.face_helper is None:
-            return np_image
-
-        self.send_model_to(devices.device_codeformer)
-
-        self.face_helper.clean_all()
-        self.face_helper.read_image(np_image)
-        self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
-        self.face_helper.align_warp_face()
-
-        for cropped_face in self.face_helper.cropped_faces:
-            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
-            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
-            cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
-
-            try:
-                with torch.no_grad():
-                    res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)
-                    if isinstance(res, tuple):
-                        output = res[0]
-                    else:
-                        output = res
-                    if not isinstance(res, torch.Tensor):
-                        raise TypeError(f"Expected torch.Tensor, got {type(res)}")
-                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
-                del output
-                devices.torch_gc()
-            except Exception:
-                errors.report('Failed inference for CodeFormer', exc_info=True)
-                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
-
-            restored_face = restored_face.astype('uint8')
-            self.face_helper.add_restored_face(restored_face)
-
-        self.face_helper.get_inverse_affine(None)
-
-        restored_img = self.face_helper.paste_faces_to_input_image()
-        restored_img = restored_img[:, :, ::-1]
+        ):
+            return modelloader.load_spandrel_model(
+                model_path,
+                device=devices.device_codeformer,
+            ).model
+        raise ValueError("No codeformer model found")
 
 
-        if original_resolution != restored_img.shape[0:2]:
-            restored_img = cv2.resize(
-                restored_img,
-                (0, 0),
-                fx=original_resolution[1]/restored_img.shape[1],
-                fy=original_resolution[0]/restored_img.shape[0],
-                interpolation=cv2.INTER_LINEAR,
-            )
+    def get_device(self):
+        return devices.device_codeformer
 
 
-        self.face_helper.clean_all()
+    def restore(self, np_image, w: float | None = None):
+        if w is None:
+            w = getattr(shared.opts, "code_former_weight", 0.5)
 
 
-        if shared.opts.face_restoration_unload:
-            self.send_model_to(devices.cpu)
+        def restore_face(cropped_face_t):
+            assert self.net is not None
+            return self.net(cropped_face_t, w=w, adain=True)[0]
 
 
-        return restored_img
+        return self.restore_with_helper(np_image, restore_face)
 
 
 
 
-def setup_model(dirname):
-    os.makedirs(model_path, exist_ok=True)
+def setup_model(dirname: str) -> None:
+    global codeformer
     try:
     try:
-        global codeformer
         codeformer = FaceRestorerCodeFormer(dirname)
         codeformer = FaceRestorerCodeFormer(dirname)
         shared.face_restorers.append(codeformer)
         shared.face_restorers.append(codeformer)
     except Exception:
     except Exception:

+ 163 - 0
modules/face_restoration_utils.py

@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+import logging
+import os
+from functools import cached_property
+from typing import TYPE_CHECKING, Callable
+
+import cv2
+import numpy as np
+import torch
+
+from modules import devices, errors, face_restoration, shared
+
+if TYPE_CHECKING:
+    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+logger = logging.getLogger(__name__)
+
+
+def create_face_helper(device) -> FaceRestoreHelper:
+    from facexlib.detection import retinaface
+    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+    if hasattr(retinaface, 'device'):
+        retinaface.device = device
+    return FaceRestoreHelper(
+        upscale_factor=1,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model='retinaface_resnet50',
+        save_ext='png',
+        use_parse=True,
+        device=device,
+    )
+
+
+def restore_with_face_helper(
+    np_image: np.ndarray,
+    face_helper: FaceRestoreHelper,
+    restore_face: Callable[[np.ndarray], np.ndarray],
+) -> np.ndarray:
+    """
+    Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
+
+    `restore_face` should take a cropped face image and return a restored face image.
+    """
+    from basicsr.utils import img2tensor, tensor2img
+    from torchvision.transforms.functional import normalize
+    np_image = np_image[:, :, ::-1]
+    original_resolution = np_image.shape[0:2]
+
+    try:
+        logger.debug("Detecting faces...")
+        face_helper.clean_all()
+        face_helper.read_image(np_image)
+        face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+        face_helper.align_warp_face()
+        logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
+        for cropped_face in face_helper.cropped_faces:
+            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+            cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+
+            try:
+                with torch.no_grad():
+                    restored_face = tensor2img(
+                        restore_face(cropped_face_t),
+                        rgb2bgr=True,
+                        min_max=(-1, 1),
+                    )
+                devices.torch_gc()
+            except Exception:
+                errors.report('Failed face-restoration inference', exc_info=True)
+                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+            restored_face = restored_face.astype('uint8')
+            face_helper.add_restored_face(restored_face)
+
+        logger.debug("Merging restored faces into image")
+        face_helper.get_inverse_affine(None)
+        img = face_helper.paste_faces_to_input_image()
+        img = img[:, :, ::-1]
+        if original_resolution != img.shape[0:2]:
+            img = cv2.resize(
+                img,
+                (0, 0),
+                fx=original_resolution[1] / img.shape[1],
+                fy=original_resolution[0] / img.shape[0],
+                interpolation=cv2.INTER_LINEAR,
+            )
+        logger.debug("Face restoration complete")
+    finally:
+        face_helper.clean_all()
+    return img
+
+
+class CommonFaceRestoration(face_restoration.FaceRestoration):
+    net: torch.Module | None
+    model_url: str
+    model_download_name: str
+
+    def __init__(self, model_path: str):
+        super().__init__()
+        self.net = None
+        self.model_path = model_path
+        os.makedirs(model_path, exist_ok=True)
+
+    @cached_property
+    def face_helper(self) -> FaceRestoreHelper:
+        return create_face_helper(self.get_device())
+
+    def send_model_to(self, device):
+        if self.net:
+            logger.debug("Sending %s to %s", self.net, device)
+            self.net.to(device)
+        if self.face_helper:
+            logger.debug("Sending face helper to %s", device)
+            self.face_helper.face_det.to(device)
+            self.face_helper.face_parse.to(device)
+
+    def get_device(self):
+        raise NotImplementedError("get_device must be implemented by subclasses")
+
+    def load_net(self) -> torch.Module:
+        raise NotImplementedError("load_net must be implemented by subclasses")
+
+    def restore_with_helper(
+        self,
+        np_image: np.ndarray,
+        restore_face: Callable[[np.ndarray], np.ndarray],
+    ) -> np.ndarray:
+        try:
+            if self.net is None:
+                self.net = self.load_net()
+        except Exception:
+            logger.warning("Unable to load face-restoration model", exc_info=True)
+            return np_image
+
+        try:
+            self.send_model_to(self.get_device())
+            return restore_with_face_helper(np_image, self.face_helper, restore_face)
+        finally:
+            if shared.opts.face_restoration_unload:
+                self.send_model_to(devices.cpu)
+
+
+def patch_facexlib(dirname: str) -> None:
+    import facexlib.detection
+    import facexlib.parsing
+
+    det_facex_load_file_from_url = facexlib.detection.load_file_from_url
+    par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
+
+    def update_kwargs(kwargs):
+        return dict(kwargs, save_dir=dirname, model_dir=None)
+
+    def facex_load_file_from_url(**kwargs):
+        return det_facex_load_file_from_url(**update_kwargs(kwargs))
+
+    def facex_load_file_from_url2(**kwargs):
+        return par_facex_load_file_from_url(**update_kwargs(kwargs))
+
+    facexlib.detection.load_file_from_url = facex_load_file_from_url
+    facexlib.parsing.load_file_from_url = facex_load_file_from_url2

+ 54 - 112
modules/gfpgan_model.py

@@ -1,126 +1,68 @@
+from __future__ import annotations
+
+import logging
 import os
 import os
 
 
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
+from modules import (
+    devices,
+    errors,
+    face_restoration,
+    face_restoration_utils,
+    modelloader,
+    shared,
+)
 
 
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_file_path = None
+logger = logging.getLogger(__name__)
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
-    global loaded_gfpgan_model
-    global model_path
-    global model_file_path
-    if loaded_gfpgan_model is not None:
-        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
-        return loaded_gfpgan_model
-
-    if gfpgan_constructor is None:
-        return None
-
-    models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
-
-    if len(models) == 1 and models[0].startswith("http"):
-        model_file = models[0]
-    elif len(models) != 0:
-        gfp_models = []
-        for item in models:
-            if 'GFPGAN' in os.path.basename(item):
-                gfp_models.append(item)
-        latest_file = max(gfp_models, key=os.path.getctime)
-        model_file = latest_file
-    else:
-        print("Unable to load gfpgan model!")
-        return None
-
-    import facexlib.detection.retinaface
-
-    if hasattr(facexlib.detection.retinaface, 'device'):
-        facexlib.detection.retinaface.device = devices.device_gfpgan
-    model_file_path = model_file
-    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
-    loaded_gfpgan_model = model
-
-    return model
-
-
-def send_model_to(model, device):
-    model.gfpgan.to(device)
-    model.face_helper.face_det.to(device)
-    model.face_helper.face_parse.to(device)
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
+
+
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+    def name(self):
+        return "GFPGAN"
+
+    def get_device(self):
+        return devices.device_gfpgan
+
+    def load_net(self) -> None:
+        for model_path in modelloader.load_models(
+            model_path=self.model_path,
+            model_url=model_url,
+            command_path=self.model_path,
+            download_name=model_download_name,
+            ext_filter=['.pth'],
+        ):
+            if 'GFPGAN' in os.path.basename(model_path):
+                net = modelloader.load_spandrel_model(
+                    model_path,
+                    device=self.get_device(),
+                ).model
+                net.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
+                return net
+        raise ValueError("No GFPGAN model found")
+
+    def restore(self, np_image):
+        def restore_face(cropped_face_t):
+            assert self.net is not None
+            return self.net(cropped_face_t, return_rgb=False)[0]
+
+        return self.restore_with_helper(np_image, restore_face)
 
 
 
 
 def gfpgan_fix_faces(np_image):
 def gfpgan_fix_faces(np_image):
-    model = gfpgann()
-    if model is None:
-        return np_image
-
-    send_model_to(model, devices.device_gfpgan)
-
-    np_image_bgr = np_image[:, :, ::-1]
-    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
-    np_image = gfpgan_output_bgr[:, :, ::-1]
-
-    model.face_helper.clean_all()
-
-    if shared.opts.face_restoration_unload:
-        send_model_to(model, devices.cpu)
-
+    if gfpgan_face_restorer:
+        return gfpgan_face_restorer.restore(np_image)
+    logger.warning("GFPGAN face restorer not set up")
     return np_image
     return np_image
 
 
 
 
-gfpgan_constructor = None
+def setup_model(dirname: str) -> None:
+    global gfpgan_face_restorer
 
 
-
-def setup_model(dirname):
     try:
     try:
-        os.makedirs(model_path, exist_ok=True)
-        import gfpgan
-        import facexlib.detection
-        import facexlib.parsing
-
-        global user_path
-        global have_gfpgan
-        global gfpgan_constructor
-        global model_file_path
-
-        facexlib_path = model_path
-
-        if dirname is not None:
-            facexlib_path = dirname
-
-        load_file_from_url_orig = gfpgan.utils.load_file_from_url
-        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
-        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
-        def my_load_file_from_url(**kwargs):
-            return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
-
-        def facex_load_file_from_url(**kwargs):
-            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
-        def facex_load_file_from_url2(**kwargs):
-            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
-        gfpgan.utils.load_file_from_url = my_load_file_from_url
-        facexlib.detection.load_file_from_url = facex_load_file_from_url
-        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
-        user_path = dirname
-        have_gfpgan = True
-        gfpgan_constructor = gfpgan.GFPGANer
-
-        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
-            def name(self):
-                return "GFPGAN"
-
-            def restore(self, np_image):
-                return gfpgan_fix_faces(np_image)
-
-        shared.face_restorers.append(FaceRestorerGFPGAN())
+        face_restoration_utils.patch_facexlib(dirname)
+        gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+        shared.face_restorers.append(gfpgan_face_restorer)
     except Exception:
     except Exception:
         errors.report("Error setting up GFPGAN", exc_info=True)
         errors.report("Error setting up GFPGAN", exc_info=True)

+ 0 - 1
requirements.txt

@@ -8,7 +8,6 @@ clean-fid
 einops
 einops
 facexlib
 facexlib
 fastapi>=0.90.1
 fastapi>=0.90.1
-gfpgan
 gradio==3.41.2
 gradio==3.41.2
 inflection
 inflection
 jsonmerge
 jsonmerge

+ 0 - 1
requirements_versions.txt

@@ -7,7 +7,6 @@ clean-fid==0.1.35
 einops==0.4.1
 einops==0.4.1
 facexlib==0.3.0
 facexlib==0.3.0
 fastapi==0.94.0
 fastapi==0.94.0
-gfpgan==1.3.8
 gradio==3.41.2
 gradio==3.41.2
 httpcore==0.15
 httpcore==0.15
 inflection==0.5.1
 inflection==0.5.1

+ 13 - 2
test/conftest.py

@@ -1,10 +1,16 @@
+import base64
 import os
 import os
 
 
 import pytest
 import pytest
-import base64
-
 
 
 test_files_path = os.path.dirname(__file__) + "/test_files"
 test_files_path = os.path.dirname(__file__) + "/test_files"
+test_outputs_path = os.path.dirname(__file__) + "/test_outputs"
+
+
+def pytest_configure(config):
+    # We don't want to fail on Py.test command line arguments being
+    # parsed by webui:
+    os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
 
 
 
 
 def file_to_base64(filename):
 def file_to_base64(filename):
@@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str:
 @pytest.fixture(scope="session")  # session so we don't read this over and over
 @pytest.fixture(scope="session")  # session so we don't read this over and over
 def mask_basic_image_base64() -> str:
 def mask_basic_image_base64() -> str:
     return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))
     return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))
+
+
+@pytest.fixture(scope="session")
+def initialize() -> None:
+    import webui  # noqa: F401

+ 29 - 0
test/test_face_restorers.py

@@ -0,0 +1,29 @@
+import os
+from test.conftest import test_files_path, test_outputs_path
+
+import numpy as np
+import pytest
+from PIL import Image
+
+
+@pytest.mark.usefixtures("initialize")
+@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
+def test_face_restorers(restorer_name):
+    from modules import shared
+
+    if restorer_name == "gfpgan":
+        from modules import gfpgan_model
+        gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
+        restorer = gfpgan_model.gfpgan_fix_faces
+    elif restorer_name == "codeformer":
+        from modules import codeformer_model
+        codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
+        restorer = codeformer_model.codeformer.restore
+    else:
+        raise NotImplementedError("...")
+    img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
+    np_img = np.array(img, dtype=np.uint8)
+    fixed_image = restorer(np_img)
+    assert fixed_image.shape == np_img.shape
+    assert not np.allclose(fixed_image, np_img)  # should have visibly changed
+    Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))

BIN
test/test_files/two-faces.jpg


+ 0 - 0
test/test_outputs/.gitkeep