upscaler_utils.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import logging
  2. from typing import Callable
  3. import numpy as np
  4. import torch
  5. import tqdm
  6. from PIL import Image
  7. from modules import devices, images, shared, torch_utils
  # Module-level logger named after this module; the upscaling helpers
  # below emit debug messages about their tiling decisions through it.
  logger = logging.getLogger(__name__)
  def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
      """
      Convert a PIL image into a CHW tensor in BGR channel order.

      The image is converted to RGB first, so palette/greyscale/RGBA inputs
      are accepted. Pixel values are divided by 255, and NumPy's true division
      of a uint8 array yields float64, so the returned tensor is float64 with
      values in [0, 1] and has no batch dimension.
      """
      rgb_hwc = np.array(img.convert("RGB"))
      # Reverse the channel axis (RGB -> BGR), then move channels first (HWC -> CHW).
      bgr_chw = np.transpose(rgb_hwc[:, :, ::-1], (2, 0, 1))
      # The flip/transpose above produce a strided view; make it contiguous before scaling.
      unit_range = np.ascontiguousarray(bgr_chw) / 255
      return torch.from_numpy(unit_range)
  def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
      """
      Convert a BGR tensor with values nominally in [0, 1] into an RGB PIL image.

      Accepts a CHW tensor, or a BCHW tensor whose batch size is 1 (the batch
      dimension is squeezed out). Values are clamped to [0, 1], scaled to
      [0, 255], rounded and cast to uint8. The caller's tensor is never
      modified, and tensors that are part of an autograd graph are accepted.

      Raises:
          ValueError: if a 4-D tensor has a batch size other than 1, or if the
              tensor is not 3-D once any batch dimension has been removed.
      """
      if tensor.ndim == 4:
          # If we're given a tensor with a batch dimension, squeeze it out
          # (but only if it's a batch of size 1).
          if tensor.shape[0] != 1:
              raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
          tensor = tensor.squeeze(0)
      # Explicit check instead of `assert`, which is stripped under `python -O`.
      if tensor.ndim != 3:
          raise ValueError(f"{tensor.shape} does not describe a CHW tensor")
      # `.float()` and `.cpu()` return `tensor` itself when it already is a
      # float32 CPU tensor, so the out-of-place `clamp` is used to avoid
      # writing into the caller's data. `detach()` lets `.numpy()` work even if
      # the tensor requires grad.
      arr = tensor.detach().float().cpu().clamp(0, 1).numpy()
      arr = 255.0 * np.moveaxis(arr, 0, 2)  # CHW to HWC, rescale
      arr = arr.round().astype(np.uint8)
      arr = arr[:, :, ::-1]  # flip BGR to RGB
      return Image.fromarray(arr, "RGB")
  def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
      """
      Run `model` once over the whole of `img` and return the output as a PIL image.

      The image becomes a single-item BGR batch whose device and dtype are taken
      from `torch_utils.get_param(model)`. The model runs with gradient tracking
      disabled and inside `devices.without_autocast()`, and its output is
      converted back with `torch_bgr_to_pil_image`.
      """
      reference_param = torch_utils.get_param(model)
      with torch.no_grad(), devices.without_autocast():
          batch = (
              pil_image_to_torch_bgr(img)
              .unsqueeze(0)  # add batch dimension
              .to(device=reference_param.device, dtype=reference_param.dtype)
          )
          upscaled = model(batch)
          return torch_bgr_to_pil_image(upscaled)
  def upscale_with_model(
      model: Callable[[torch.Tensor], torch.Tensor],
      img: Image.Image,
      *,
      tile_size: int,
      tile_overlap: int = 0,
      desc="tiled upscale",
  ) -> Image.Image:
      """
      Upscale a PIL image with `model`, optionally splitting it into tiles.

      With `tile_size <= 0` the whole image goes through `upscale_pil_patch` in
      one call. Otherwise the image is split with `images.split_grid`, each tile
      is upscaled independently, and the results are stitched back together
      with `images.combine_grid`. The scale factor is inferred per tile as
      `output.width // tile.width`, and grid geometry is scaled by it.

      If `shared.state` reports an interrupt or skip while tiles are being
      processed, the remaining work is abandoned and `img` is returned
      unchanged. `tiled_upscale_2` checks the same two flags.

      Args:
          model: callable passed through to `upscale_pil_patch`.
          img: image to upscale.
          tile_size: tile edge length in input pixels; `<= 0` disables tiling.
          tile_overlap: overlap between neighbouring tiles, in input pixels.
          desc: label for the progress bar.
      """
      if tile_size <= 0:
          logger.debug("Upscaling %s without tiling", img)
          output = upscale_pil_patch(model, img)
          logger.debug("=> %s", output)
          return output

      grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
      newtiles = []

      with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p:
          for y, h, row in grid.tiles:
              newrow = []
              for x, w, tile in row:
                  # Honour cancellation between tiles; a partially processed
                  # grid cannot be recombined, so hand back the input image.
                  if shared.state.interrupted or shared.state.skipped:
                      return img
                  output = upscale_pil_patch(model, tile)
                  scale_factor = output.width // tile.width
                  newrow.append([x * scale_factor, w * scale_factor, output])
                  p.update(1)
              newtiles.append([y * scale_factor, h * scale_factor, newrow])

      newgrid = images.Grid(
          newtiles,
          tile_w=grid.tile_w * scale_factor,
          tile_h=grid.tile_h * scale_factor,
          image_w=grid.image_w * scale_factor,
          image_h=grid.image_h * scale_factor,
          overlap=grid.overlap * scale_factor,
      )
      return images.combine_grid(newgrid)
  def _tile_starts(length: int, tile_size: int, stride: int) -> list[int]:
      """
      Return the start offsets of `tile_size`-long tiles along an axis of `length`.

      Offsets step by `stride` from 0, and a final tile aligned to the end of the
      axis is always appended. With `0 < stride <= tile_size <= length`, the tiles
      therefore cover `[0, length)` completely.
      """
      return list(range(0, length - tile_size, stride)) + [length - tile_size]


  def tiled_upscale_2(
      img: torch.Tensor,
      model,
      *,
      tile_size: int,
      tile_overlap: int,
      scale: int,
      device: torch.device,
      desc="Tiled upscale",
  ):
      """
      Upscale a 4-D tensor by running `model` over overlapping square tiles.

      This is an alternative implementation of `upscale_with_model`, originally
      used by SwinIR and ScuNET. It differs from `upscale_with_model` in that
      tiling and weighting are done in PyTorch space, as opposed to
      `images.Grid` doing it in Pillow space without weighting. Each output
      pixel is the sum of every tile output covering it divided by the number
      of such tiles.

      Args:
          img: input tensor unpacked as (b, c, h, w); only the last two
              dimensions are tiled.
          model: callable applied to each tile; its output is expected to be
              `scale` times larger in both spatial dimensions.
          tile_size: tile edge length in input pixels, clamped to the image's
              height and width. `<= 0` disables tiling.
          tile_overlap: overlap between neighbouring tiles in input pixels.
              Negative values are treated as 0. If the overlap leaves no
              positive stride (possible once `tile_size` has been clamped to a
              small image), tiling falls back to no overlap.
          scale: upscaling factor applied by `model`.
          device: device that tiles are moved to and the output is allocated on.
          desc: label for the progress bar.

      Returns:
          A tensor of shape (b, c, h * scale, w * scale) with `img.dtype` on
          `device`. When tiling is disabled, whatever `model` returns for the
          whole input. If `shared.state` reports an interrupt or skip, tiling
          stops early and output regions no tile reached are left as 0.
      """
      b, c, h, w = img.size()
      tile_size = min(tile_size, h, w)

      if tile_size <= 0:
          logger.debug("Upscaling %s without tiling", img.shape)
          # The tiled path moves each patch to `device`; do the same here so a
          # CPU input (as built by `upscale_2`) works with a model elsewhere.
          return model(img.to(device=device))

      # A stride of 0 would make `range()` raise, and a negative one would
      # leave most of the output uncovered, so fall back to non-overlapping
      # tiles. A negative overlap would leave gaps between tiles.
      stride = tile_size - max(tile_overlap, 0)
      if stride <= 0:
          stride = tile_size
      h_idx_list = _tile_starts(h, tile_size, stride)
      w_idx_list = _tile_starts(w, tile_size, stride)

      result = torch.zeros(b, c, h * scale, w * scale, device=device, dtype=img.dtype)
      # Per-pixel count of contributing tiles. It is identical for every batch
      # item and channel, so keep one plane and let the division broadcast.
      weights = torch.zeros(1, 1, h * scale, w * scale, device=device, dtype=img.dtype)
      logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)

      with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
          for h_idx in h_idx_list:
              if shared.state.interrupted or shared.state.skipped:
                  break
              for w_idx in w_idx_list:
                  if shared.state.interrupted or shared.state.skipped:
                      break
                  # Move only this patch to the device (`.to()` is a no-op if
                  # it is already there).
                  in_patch = img[..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size].to(device=device)
                  out_patch = model(in_patch)
                  out_rows = slice(h_idx * scale, (h_idx + tile_size) * scale)
                  out_cols = slice(w_idx * scale, (w_idx + tile_size) * scale)
                  result[..., out_rows, out_cols].add_(out_patch)
                  weights[..., out_rows, out_cols].add_(1)
                  pbar.update(1)

      # Pixels that no tile reached (only after an interrupt/skip) have weight
      # 0 and a zero sum; clamping makes them 0 instead of 0/0 = NaN. Covered
      # pixels have integer weights >= 1, so they are unaffected.
      return result.div_(weights.clamp_(min=1))
  def upscale_2(
      img: Image.Image,
      model,
      *,
      tile_size: int,
      tile_overlap: int,
      scale: int,
      desc: str,
  ):
      """
      Upscale a PIL image through `tiled_upscale_2`.

      The image is converted to a single-item BGR batch cast to the dtype
      reported by `torch_utils.get_param(model)`. The tensor itself is not
      moved; that parameter's device is passed to `tiled_upscale_2` as
      `device=`. Upscaling runs with gradient tracking disabled, and the
      result is converted back to a PIL image.
      """
      model_param = torch_utils.get_param(model)
      batch = pil_image_to_torch_bgr(img).to(dtype=model_param.dtype)
      batch = batch.unsqueeze(0)  # add batch dimension

      with torch.no_grad():
          upscaled = tiled_upscale_2(
              batch,
              model,
              device=model_param.device,
              desc=desc,
              scale=scale,
              tile_overlap=tile_overlap,
              tile_size=tile_size,
          )
      return torch_bgr_to_pil_image(upscaled)