upscaler_utils.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import logging
  2. from typing import Callable
  3. import numpy as np
  4. import torch
  5. import tqdm
  6. from PIL import Image
  7. from modules import images, shared, torch_utils
  8. logger = logging.getLogger(__name__)
  9. def upscale_without_tiling(model, img: Image.Image):
  10. img = np.array(img)
  11. img = img[:, :, ::-1]
  12. img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
  13. img = torch.from_numpy(img).float()
  14. param = torch_utils.get_param(model)
  15. img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
  16. with torch.no_grad():
  17. output = model(img)
  18. output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
  19. output = 255. * np.moveaxis(output, 0, 2)
  20. output = output.astype(np.uint8)
  21. output = output[:, :, ::-1]
  22. return Image.fromarray(output, 'RGB')
  23. def upscale_with_model(
  24. model: Callable[[torch.Tensor], torch.Tensor],
  25. img: Image.Image,
  26. *,
  27. tile_size: int,
  28. tile_overlap: int = 0,
  29. desc="tiled upscale",
  30. ) -> Image.Image:
  31. if tile_size <= 0:
  32. logger.debug("Upscaling %s without tiling", img)
  33. output = upscale_without_tiling(model, img)
  34. logger.debug("=> %s", output)
  35. return output
  36. grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
  37. newtiles = []
  38. with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
  39. for y, h, row in grid.tiles:
  40. newrow = []
  41. for x, w, tile in row:
  42. logger.debug("Tile (%d, %d) %s...", x, y, tile)
  43. output = upscale_without_tiling(model, tile)
  44. scale_factor = output.width // tile.width
  45. logger.debug("=> %s (scale factor %s)", output, scale_factor)
  46. newrow.append([x * scale_factor, w * scale_factor, output])
  47. p.update(1)
  48. newtiles.append([y * scale_factor, h * scale_factor, newrow])
  49. newgrid = images.Grid(
  50. newtiles,
  51. tile_w=grid.tile_w * scale_factor,
  52. tile_h=grid.tile_h * scale_factor,
  53. image_w=grid.image_w * scale_factor,
  54. image_h=grid.image_h * scale_factor,
  55. overlap=grid.overlap * scale_factor,
  56. )
  57. return images.combine_grid(newgrid)
  58. def tiled_upscale_2(
  59. img,
  60. model,
  61. *,
  62. tile_size: int,
  63. tile_overlap: int,
  64. scale: int,
  65. device,
  66. desc="Tiled upscale",
  67. ):
  68. # Alternative implementation of `upscale_with_model` originally used by
  69. # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
  70. # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
  71. # Pillow space without weighting.
  72. b, c, h, w = img.size()
  73. tile_size = min(tile_size, h, w)
  74. if tile_size <= 0:
  75. logger.debug("Upscaling %s without tiling", img.shape)
  76. return model(img)
  77. stride = tile_size - tile_overlap
  78. h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
  79. w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
  80. result = torch.zeros(
  81. b,
  82. c,
  83. h * scale,
  84. w * scale,
  85. device=device,
  86. ).type_as(img)
  87. weights = torch.zeros_like(result)
  88. logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
  89. with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
  90. for h_idx in h_idx_list:
  91. if shared.state.interrupted or shared.state.skipped:
  92. break
  93. for w_idx in w_idx_list:
  94. if shared.state.interrupted or shared.state.skipped:
  95. break
  96. in_patch = img[
  97. ...,
  98. h_idx : h_idx + tile_size,
  99. w_idx : w_idx + tile_size,
  100. ]
  101. out_patch = model(in_patch)
  102. result[
  103. ...,
  104. h_idx * scale : (h_idx + tile_size) * scale,
  105. w_idx * scale : (w_idx + tile_size) * scale,
  106. ].add_(out_patch)
  107. out_patch_mask = torch.ones_like(out_patch)
  108. weights[
  109. ...,
  110. h_idx * scale : (h_idx + tile_size) * scale,
  111. w_idx * scale : (w_idx + tile_size) * scale,
  112. ].add_(out_patch_mask)
  113. pbar.update(1)
  114. output = result.div_(weights)
  115. return output