浏览代码

Merge pull request #14476 from akx/dedupe-tiled-weighted-inference

Deduplicate tiled inference code from SwinIR/ScuNET
AUTOMATIC1111 1 年之前
父节点
当前提交
a84e842189

+ 11 - 44
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -3,12 +3,11 @@ import sys
 import PIL.Image
 import numpy as np
 import torch
-from tqdm import tqdm
 
 import modules.upscaler
 from modules import devices, modelloader, script_callbacks, errors
-
 from modules.shared import opts
+from modules.upscaler_utils import tiled_upscale_2
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
             scalers.append(scaler_data2)
         self.scalers = scalers
 
-    @staticmethod
-    @torch.no_grad()
-    def tiled_inference(img, model):
-        # test the image tile by tile
-        h, w = img.shape[2:]
-        tile = opts.SCUNET_tile
-        tile_overlap = opts.SCUNET_tile_overlap
-        if tile == 0:
-            return model(img)
-
-        device = devices.get_device_for('scunet')
-        assert tile % 8 == 0, "tile size should be a multiple of window_size"
-        sf = 1
-
-        stride = tile - tile_overlap
-        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
-        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
-        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
-            for h_idx in h_idx_list:
-
-                for w_idx in w_idx_list:
-
-                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-
-                    out_patch = model(in_patch)
-                    out_patch_mask = torch.ones_like(out_patch)
-
-                    E[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch)
-                    W[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch_mask)
-                    pbar.update(1)
-        output = E.div_(W)
-
-        return output
-
     def do_upscale(self, img: PIL.Image.Image, selected_file):
 
         devices.torch_gc()
@@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
             _img[:, :, :h, :w] = torch_img # pad image
             torch_img = _img
 
-        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+        with torch.no_grad():
+            torch_output = tiled_upscale_2(
+                torch_img,
+                model,
+                tile_size=opts.SCUNET_tile,
+                tile_overlap=opts.SCUNET_tile_overlap,
+                scale=1,
+                device=devices.get_device_for('scunet'),
+                desc="ScuNET tiles",
+            ).squeeze(0)
         torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
         np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
         del torch_img, torch_output

+ 5 - 52
extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -4,11 +4,11 @@ import sys
 import numpy as np
 import torch
 from PIL import Image
-from tqdm import tqdm
 
 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts, state
+from modules.shared import opts
 from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import tiled_upscale_2
 
 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
 
@@ -110,14 +110,14 @@ def upscale(
         w_pad = (w_old // window_size + 1) * window_size - w_old
         img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
         img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = inference(
+        output = tiled_upscale_2(
             img,
             model,
-            tile=tile,
+            tile_size=tile,
             tile_overlap=tile_overlap,
-            window_size=window_size,
             scale=scale,
             device=device,
+            desc="SwinIR tiles",
         )
         output = output[..., : h_old * scale, : w_old * scale]
         output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@@ -129,53 +129,6 @@ def upscale(
         return Image.fromarray(output, "RGB")
 
 
-def inference(
-    img,
-    model,
-    *,
-    tile: int,
-    tile_overlap: int,
-    window_size: int,
-    scale: int,
-    device,
-):
-    # test the image tile by tile
-    b, c, h, w = img.size()
-    tile = min(tile, h, w)
-    assert tile % window_size == 0, "tile size should be a multiple of window_size"
-    sf = scale
-
-    stride = tile - tile_overlap
-    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
-    W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
-    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
-        for h_idx in h_idx_list:
-            if state.interrupted or state.skipped:
-                break
-
-            for w_idx in w_idx_list:
-                if state.interrupted or state.skipped:
-                    break
-
-                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                out_patch = model(in_patch)
-                out_patch_mask = torch.ones_like(out_patch)
-
-                E[
-                ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch)
-                W[
-                ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch_mask)
-                pbar.update(1)
-    output = E.div_(W)
-
-    return output
-
-
 def on_ui_settings():
     import gradio as gr
 

+ 71 - 1
modules/upscaler_utils.py

@@ -6,7 +6,7 @@ import torch
 import tqdm
 from PIL import Image
 
-from modules import images
+from modules import images, shared
 
 logger = logging.getLogger(__name__)
 
@@ -68,3 +68,73 @@ def upscale_with_model(
         overlap=grid.overlap * scale_factor,
     )
     return images.combine_grid(newgrid)
+
+
+def tiled_upscale_2(
+    img,
+    model,
+    *,
+    tile_size: int,
+    tile_overlap: int,
+    scale: int,
+    device,
+    desc="Tiled upscale",
+):
+    # Alternative implementation of `upscale_with_model` originally used by
+    # SwinIR and ScuNET.  It differs from `upscale_with_model` in that tiling and
+    # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
+    # Pillow space without weighting.
+    b, c, h, w = img.size()
+    tile_size = min(tile_size, h, w)
+
+    if tile_size <= 0:
+        logger.debug("Upscaling %s without tiling", img.shape)
+        return model(img)
+
+    stride = tile_size - tile_overlap
+    h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
+    w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
+    result = torch.zeros(
+        b,
+        c,
+        h * scale,
+        w * scale,
+        device=device,
+    ).type_as(img)
+    weights = torch.zeros_like(result)
+    logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
+    with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
+        for h_idx in h_idx_list:
+            if shared.state.interrupted or shared.state.skipped:
+                break
+
+            for w_idx in w_idx_list:
+                if shared.state.interrupted or shared.state.skipped:
+                    break
+
+                in_patch = img[
+                    ...,
+                    h_idx : h_idx + tile_size,
+                    w_idx : w_idx + tile_size,
+                ]
+                out_patch = model(in_patch)
+
+                result[
+                    ...,
+                    h_idx * scale : (h_idx + tile_size) * scale,
+                    w_idx * scale : (w_idx + tile_size) * scale,
+                ].add_(out_patch)
+
+                out_patch_mask = torch.ones_like(out_patch)
+
+                weights[
+                    ...,
+                    h_idx * scale : (h_idx + tile_size) * scale,
+                    w_idx * scale : (w_idx + tile_size) * scale,
+                ].add_(out_patch_mask)
+
+                pbar.update(1)
+
+    output = result.div_(weights)
+
+    return output