swinir_model.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import logging
  2. import sys
  3. import numpy as np
  4. import torch
  5. from PIL import Image
  6. from tqdm import tqdm
  7. from modules import modelloader, devices, script_callbacks, shared
  8. from modules.shared import opts, state
  9. from modules.upscaler import Upscaler, UpscalerData
  10. SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
  11. logger = logging.getLogger(__name__)
  12. class UpscalerSwinIR(Upscaler):
  13. def __init__(self, dirname):
  14. self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
  15. self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
  16. self.name = "SwinIR"
  17. self.model_url = SWINIR_MODEL_URL
  18. self.model_name = "SwinIR 4x"
  19. self.user_path = dirname
  20. super().__init__()
  21. scalers = []
  22. model_files = self.find_models(ext_filter=[".pt", ".pth"])
  23. for model in model_files:
  24. if model.startswith("http"):
  25. name = self.model_name
  26. else:
  27. name = modelloader.friendly_name(model)
  28. model_data = UpscalerData(name, model, self)
  29. scalers.append(model_data)
  30. self.scalers = scalers
  31. def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
  32. current_config = (model_file, opts.SWIN_tile)
  33. device = self._get_device()
  34. if self._cached_model_config == current_config:
  35. model = self._cached_model
  36. else:
  37. try:
  38. model = self.load_model(model_file)
  39. except Exception as e:
  40. print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
  41. return img
  42. self._cached_model = model
  43. self._cached_model_config = current_config
  44. img = upscale(
  45. img,
  46. model,
  47. tile=opts.SWIN_tile,
  48. tile_overlap=opts.SWIN_tile_overlap,
  49. device=device,
  50. )
  51. devices.torch_gc()
  52. return img
  53. def load_model(self, path, scale=4):
  54. if path.startswith("http"):
  55. filename = modelloader.load_file_from_url(
  56. url=path,
  57. model_dir=self.model_download_path,
  58. file_name=f"{self.model_name.replace(' ', '_')}.pth",
  59. )
  60. else:
  61. filename = path
  62. model = modelloader.load_spandrel_model(
  63. filename,
  64. device=self._get_device(),
  65. dtype=devices.dtype,
  66. expected_architecture="SwinIR",
  67. )
  68. if getattr(opts, 'SWIN_torch_compile', False):
  69. try:
  70. model = torch.compile(model)
  71. except Exception:
  72. logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
  73. return model
  74. def _get_device(self):
  75. return devices.get_device_for('swinir')
  76. def upscale(
  77. img,
  78. model,
  79. *,
  80. tile: int,
  81. tile_overlap: int,
  82. window_size=8,
  83. scale=4,
  84. device,
  85. ):
  86. img = np.array(img)
  87. img = img[:, :, ::-1]
  88. img = np.moveaxis(img, 2, 0) / 255
  89. img = torch.from_numpy(img).float()
  90. img = img.unsqueeze(0).to(device, dtype=devices.dtype)
  91. with torch.no_grad(), devices.autocast():
  92. _, _, h_old, w_old = img.size()
  93. h_pad = (h_old // window_size + 1) * window_size - h_old
  94. w_pad = (w_old // window_size + 1) * window_size - w_old
  95. img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
  96. img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
  97. output = inference(
  98. img,
  99. model,
  100. tile=tile,
  101. tile_overlap=tile_overlap,
  102. window_size=window_size,
  103. scale=scale,
  104. device=device,
  105. )
  106. output = output[..., : h_old * scale, : w_old * scale]
  107. output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
  108. if output.ndim == 3:
  109. output = np.transpose(
  110. output[[2, 1, 0], :, :], (1, 2, 0)
  111. ) # CHW-RGB to HCW-BGR
  112. output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
  113. return Image.fromarray(output, "RGB")
  114. def inference(
  115. img,
  116. model,
  117. *,
  118. tile: int,
  119. tile_overlap: int,
  120. window_size: int,
  121. scale: int,
  122. device,
  123. ):
  124. # test the image tile by tile
  125. b, c, h, w = img.size()
  126. tile = min(tile, h, w)
  127. assert tile % window_size == 0, "tile size should be a multiple of window_size"
  128. sf = scale
  129. stride = tile - tile_overlap
  130. h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
  131. w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
  132. E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
  133. W = torch.zeros_like(E, dtype=devices.dtype, device=device)
  134. with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
  135. for h_idx in h_idx_list:
  136. if state.interrupted or state.skipped:
  137. break
  138. for w_idx in w_idx_list:
  139. if state.interrupted or state.skipped:
  140. break
  141. in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
  142. out_patch = model(in_patch)
  143. out_patch_mask = torch.ones_like(out_patch)
  144. E[
  145. ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
  146. ].add_(out_patch)
  147. W[
  148. ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
  149. ].add_(out_patch_mask)
  150. pbar.update(1)
  151. output = E.div_(W)
  152. return output
  153. def on_ui_settings():
  154. import gradio as gr
  155. shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
  156. shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
  157. shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
  158. script_callbacks.on_ui_settings(on_ui_settings)