فهرست منبع

Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right

Aarni Koskela 1 سال پیش
والد
کامیت
cf14a6a7aa
3فایلهای تغییر یافته به همراه87 افزوده شده و 112 حذف شده
  1. 10 38
      extensions-builtin/ScuNET/scripts/scunet_model.py
  2. 8 54
      extensions-builtin/SwinIR/scripts/swinir_model.py
  3. 69 20
      modules/upscaler_utils.py

+ 10 - 38
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -1,13 +1,9 @@
 import sys
 
 import PIL.Image
-import numpy as np
-import torch
 
 import modules.upscaler
-from modules import devices, modelloader, script_callbacks, errors
-from modules.shared import opts
-from modules.upscaler_utils import tiled_upscale_2
+from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         self.scalers = scalers
 
     def do_upscale(self, img: PIL.Image.Image, selected_file):
-
         devices.torch_gc()
-
         try:
             model = self.load_model(selected_file)
         except Exception as e:
             print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
             return img
 
-        device = devices.get_device_for('scunet')
-        tile = opts.SCUNET_tile
-        h, w = img.height, img.width
-        np_img = np.array(img)
-        np_img = np_img[:, :, ::-1]  # RGB to BGR
-        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
-        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
-
-        if tile > h or tile > w:
-            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
-            _img[:, :, :h, :w] = torch_img # pad image
-            torch_img = _img
-
-        with torch.no_grad():
-            torch_output = tiled_upscale_2(
-                torch_img,
-                model,
-                tile_size=opts.SCUNET_tile,
-                tile_overlap=opts.SCUNET_tile_overlap,
-                scale=1,
-                device=devices.get_device_for('scunet'),
-                desc="ScuNET tiles",
-            ).squeeze(0)
-        torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
-        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
-        del torch_img, torch_output
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SCUNET_tile,
+            tile_overlap=shared.opts.SCUNET_tile_overlap,
+            scale=1,  # ScuNET is a denoising model, not an upscaler
+            desc='ScuNET',
+        )
         devices.torch_gc()
-
-        output = np_output.transpose((1, 2, 0))  # CHW to HWC
-        output = output[:, :, ::-1]  # BGR to RGB
-        return PIL.Image.fromarray((output * 255).astype(np.uint8))
+        return img
 
     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
@@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
 def on_ui_settings():
     import gradio as gr
-    from modules import shared
 
     shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
     shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))

+ 8 - 54
extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -1,14 +1,10 @@
 import logging
 import sys
 
-import numpy as np
-import torch
 from PIL import Image
 
-from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts
+from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
 from modules.upscaler import Upscaler, UpscalerData
-from modules.upscaler_utils import tiled_upscale_2
 
 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
 
@@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler):
         self.scalers = scalers
 
     def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
-        current_config = (model_file, opts.SWIN_tile)
-
-        device = self._get_device()
+        current_config = (model_file, shared.opts.SWIN_tile)
 
         if self._cached_model_config == current_config:
             model = self._cached_model
@@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler):
             self._cached_model = model
             self._cached_model_config = current_config
 
-        img = upscale(
+        img = upscaler_utils.upscale_2(
             img,
             model,
-            tile=opts.SWIN_tile,
-            tile_overlap=opts.SWIN_tile_overlap,
-            device=device,
+            tile_size=shared.opts.SWIN_tile,
+            tile_overlap=shared.opts.SWIN_tile_overlap,
+            scale=4,  # TODO: This was hard-coded before too...
+            desc="SwinIR",
         )
         devices.torch_gc()
         return img
@@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler):
             dtype=devices.dtype,
             expected_architecture="SwinIR",
         )
-        if getattr(opts, 'SWIN_torch_compile', False):
+        if getattr(shared.opts, 'SWIN_torch_compile', False):
             try:
                 model_descriptor.model.compile()
             except Exception:
@@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler):
         return devices.get_device_for('swinir')
 
 
-def upscale(
-    img,
-    model,
-    *,
-    tile: int,
-    tile_overlap: int,
-    window_size=8,
-    scale=4,
-    device,
-):
-
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.moveaxis(img, 2, 0) / 255
-    img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device, dtype=devices.dtype)
-    with torch.no_grad(), devices.autocast():
-        _, _, h_old, w_old = img.size()
-        h_pad = (h_old // window_size + 1) * window_size - h_old
-        w_pad = (w_old // window_size + 1) * window_size - w_old
-        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
-        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = tiled_upscale_2(
-            img,
-            model,
-            tile_size=tile,
-            tile_overlap=tile_overlap,
-            scale=scale,
-            device=device,
-            desc="SwinIR tiles",
-        )
-        output = output[..., : h_old * scale, : w_old * scale]
-        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        if output.ndim == 3:
-            output = np.transpose(
-                output[[2, 1, 0], :, :], (1, 2, 0)
-            )  # CHW-RGB to HCW-BGR
-        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
-        return Image.fromarray(output, "RGB")
-
-
 def on_ui_settings():
     import gradio as gr
 

+ 69 - 20
modules/upscaler_utils.py

@@ -11,23 +11,40 @@ from modules import images, shared, torch_utils
 logger = logging.getLogger(__name__)
 
 
-def upscale_without_tiling(model, img: Image.Image):
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
-    img = torch.from_numpy(img).float()
-
+def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
+    img = np.array(img.convert("RGB"))
+    img = img[:, :, ::-1]  # flip RGB to BGR
+    img = np.transpose(img, (2, 0, 1))  # HWC to CHW
+    img = np.ascontiguousarray(img) / 255  # Rescale to [0, 1]
+    return torch.from_numpy(img)
+
+
+def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
+    if tensor.ndim == 4:
+        # If we're given a tensor with a batch dimension, squeeze it out
+        # (but only if it's a batch of size 1).
+        if tensor.shape[0] != 1:
+            raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
+        tensor = tensor.squeeze(0)
+    assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
+    # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
+    arr = tensor.float().cpu().clamp_(0, 1).numpy()  # clamp
+    arr = 255.0 * np.moveaxis(arr, 0, 2)  # CHW to HWC, rescale
+    arr = arr.astype(np.uint8)
+    arr = arr[:, :, ::-1]  # flip BGR to RGB
+    return Image.fromarray(arr, "RGB")
+
+
+def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
+    """
+    Upscale a given PIL image using the given model.
+    """
     param = torch_utils.get_param(model)
-    img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
 
     with torch.no_grad():
-        output = model(img)
-
-    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
-    output = 255. * np.moveaxis(output, 0, 2)
-    output = output.astype(np.uint8)
-    output = output[:, :, ::-1]
-    return Image.fromarray(output, 'RGB')
+        tensor = pil_image_to_torch_bgr(img).unsqueeze(0)  # add batch dimension
+        tensor = tensor.to(device=param.device, dtype=param.dtype)
+        return torch_bgr_to_pil_image(model(tensor))
 
 
 def upscale_with_model(
@@ -40,7 +57,7 @@ def upscale_with_model(
 ) -> Image.Image:
     if tile_size <= 0:
         logger.debug("Upscaling %s without tiling", img)
-        output = upscale_without_tiling(model, img)
+        output = upscale_pil_patch(model, img)
         logger.debug("=> %s", output)
         return output
 
@@ -52,7 +69,7 @@ def upscale_with_model(
             newrow = []
             for x, w, tile in row:
                 logger.debug("Tile (%d, %d) %s...", x, y, tile)
-                output = upscale_without_tiling(model, tile)
+                output = upscale_pil_patch(model, tile)
                 scale_factor = output.width // tile.width
                 logger.debug("=> %s (scale factor %s)", output, scale_factor)
                 newrow.append([x * scale_factor, w * scale_factor, output])
@@ -71,19 +88,22 @@ def upscale_with_model(
 
 
 def tiled_upscale_2(
-    img,
+    img: torch.Tensor,
     model,
     *,
     tile_size: int,
     tile_overlap: int,
     scale: int,
-    device,
     desc="Tiled upscale",
 ):
     # Alternative implementation of `upscale_with_model` originally used by
     # SwinIR and ScuNET.  It differs from `upscale_with_model` in that tiling and
     # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
     # Pillow space without weighting.
+
+    # Grab the device the model is on, and use it.
+    device = torch_utils.get_param(model).device
+
     b, c, h, w = img.size()
     tile_size = min(tile_size, h, w)
 
@@ -100,7 +120,8 @@ def tiled_upscale_2(
         h * scale,
         w * scale,
         device=device,
-    ).type_as(img)
+        dtype=img.dtype,
+    )
     weights = torch.zeros_like(result)
     logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
     with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
@@ -112,11 +133,13 @@ def tiled_upscale_2(
                 if shared.state.interrupted or shared.state.skipped:
                     break
 
+                # Only move this patch to the device if it's not already there.
                 in_patch = img[
                     ...,
                     h_idx : h_idx + tile_size,
                     w_idx : w_idx + tile_size,
-                ]
+                ].to(device=device)
+
                 out_patch = model(in_patch)
 
                 result[
@@ -138,3 +161,29 @@ def tiled_upscale_2(
     output = result.div_(weights)
 
     return output
+
+
+def upscale_2(
+    img: Image.Image,
+    model,
+    *,
+    tile_size: int,
+    tile_overlap: int,
+    scale: int,
+    desc: str,
+):
+    """
+    Convenience wrapper around `tiled_upscale_2` that handles PIL images.
+    """
+    tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0)  # add batch dimension
+
+    with torch.no_grad():
+        output = tiled_upscale_2(
+            tensor,
+            model,
+            tile_size=tile_size,
+            tile_overlap=tile_overlap,
+            scale=scale,
+            desc=desc,
+        )
+    return torch_bgr_to_pil_image(output)