# esrgan_model.py (8.8 KB)
  1. import os
  2. import numpy as np
  3. import torch
  4. from PIL import Image
  5. from basicsr.utils.download_util import load_file_from_url
  6. import modules.esrgan_model_arch as arch
  7. from modules import modelloader, images, devices
  8. from modules.upscaler import Upscaler, UpscalerData
  9. from modules.shared import opts
  10. def mod2normal(state_dict):
  11. # this code is copied from https://github.com/victorca25/iNNfer
  12. if 'conv_first.weight' in state_dict:
  13. crt_net = {}
  14. items = list(state_dict)
  15. crt_net['model.0.weight'] = state_dict['conv_first.weight']
  16. crt_net['model.0.bias'] = state_dict['conv_first.bias']
  17. for k in items.copy():
  18. if 'RDB' in k:
  19. ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
  20. if '.weight' in k:
  21. ori_k = ori_k.replace('.weight', '.0.weight')
  22. elif '.bias' in k:
  23. ori_k = ori_k.replace('.bias', '.0.bias')
  24. crt_net[ori_k] = state_dict[k]
  25. items.remove(k)
  26. crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
  27. crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
  28. crt_net['model.3.weight'] = state_dict['upconv1.weight']
  29. crt_net['model.3.bias'] = state_dict['upconv1.bias']
  30. crt_net['model.6.weight'] = state_dict['upconv2.weight']
  31. crt_net['model.6.bias'] = state_dict['upconv2.bias']
  32. crt_net['model.8.weight'] = state_dict['HRconv.weight']
  33. crt_net['model.8.bias'] = state_dict['HRconv.bias']
  34. crt_net['model.10.weight'] = state_dict['conv_last.weight']
  35. crt_net['model.10.bias'] = state_dict['conv_last.bias']
  36. state_dict = crt_net
  37. return state_dict
  38. def resrgan2normal(state_dict, nb=23):
  39. # this code is copied from https://github.com/victorca25/iNNfer
  40. if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
  41. re8x = 0
  42. crt_net = {}
  43. items = list(state_dict)
  44. crt_net['model.0.weight'] = state_dict['conv_first.weight']
  45. crt_net['model.0.bias'] = state_dict['conv_first.bias']
  46. for k in items.copy():
  47. if "rdb" in k:
  48. ori_k = k.replace('body.', 'model.1.sub.')
  49. ori_k = ori_k.replace('.rdb', '.RDB')
  50. if '.weight' in k:
  51. ori_k = ori_k.replace('.weight', '.0.weight')
  52. elif '.bias' in k:
  53. ori_k = ori_k.replace('.bias', '.0.bias')
  54. crt_net[ori_k] = state_dict[k]
  55. items.remove(k)
  56. crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
  57. crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
  58. crt_net['model.3.weight'] = state_dict['conv_up1.weight']
  59. crt_net['model.3.bias'] = state_dict['conv_up1.bias']
  60. crt_net['model.6.weight'] = state_dict['conv_up2.weight']
  61. crt_net['model.6.bias'] = state_dict['conv_up2.bias']
  62. if 'conv_up3.weight' in state_dict:
  63. # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
  64. re8x = 3
  65. crt_net['model.9.weight'] = state_dict['conv_up3.weight']
  66. crt_net['model.9.bias'] = state_dict['conv_up3.bias']
  67. crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
  68. crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
  69. crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
  70. crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
  71. state_dict = crt_net
  72. return state_dict
  73. def infer_params(state_dict):
  74. # this code is copied from https://github.com/victorca25/iNNfer
  75. scale2x = 0
  76. scalemin = 6
  77. n_uplayer = 0
  78. plus = False
  79. for block in list(state_dict):
  80. parts = block.split(".")
  81. n_parts = len(parts)
  82. if n_parts == 5 and parts[2] == "sub":
  83. nb = int(parts[3])
  84. elif n_parts == 3:
  85. part_num = int(parts[1])
  86. if (part_num > scalemin
  87. and parts[0] == "model"
  88. and parts[2] == "weight"):
  89. scale2x += 1
  90. if part_num > n_uplayer:
  91. n_uplayer = part_num
  92. out_nc = state_dict[block].shape[0]
  93. if not plus and "conv1x1" in block:
  94. plus = True
  95. nf = state_dict["model.0.weight"].shape[0]
  96. in_nc = state_dict["model.0.weight"].shape[1]
  97. out_nc = out_nc
  98. scale = 2 ** scale2x
  99. return in_nc, out_nc, nf, nb, plus, scale
  100. class UpscalerESRGAN(Upscaler):
  101. def __init__(self, dirname):
  102. self.name = "ESRGAN"
  103. self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
  104. self.model_name = "ESRGAN_4x"
  105. self.scalers = []
  106. self.user_path = dirname
  107. super().__init__()
  108. model_paths = self.find_models(ext_filter=[".pt", ".pth"])
  109. scalers = []
  110. if len(model_paths) == 0:
  111. scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
  112. scalers.append(scaler_data)
  113. for file in model_paths:
  114. if "http" in file:
  115. name = self.model_name
  116. else:
  117. name = modelloader.friendly_name(file)
  118. scaler_data = UpscalerData(name, file, self, 4)
  119. self.scalers.append(scaler_data)
  120. def do_upscale(self, img, selected_model):
  121. model = self.load_model(selected_model)
  122. if model is None:
  123. return img
  124. model.to(devices.device_esrgan)
  125. img = esrgan_upscale(model, img)
  126. return img
  127. def load_model(self, path: str):
  128. if "http" in path:
  129. filename = load_file_from_url(
  130. url=self.model_url,
  131. model_dir=self.model_download_path,
  132. file_name=f"{self.model_name}.pth",
  133. progress=True,
  134. )
  135. else:
  136. filename = path
  137. if not os.path.exists(filename) or filename is None:
  138. print(f"Unable to load {self.model_path} from {filename}")
  139. return None
  140. state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
  141. if "params_ema" in state_dict:
  142. state_dict = state_dict["params_ema"]
  143. elif "params" in state_dict:
  144. state_dict = state_dict["params"]
  145. num_conv = 16 if "realesr-animevideov3" in filename else 32
  146. model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
  147. model.load_state_dict(state_dict)
  148. model.eval()
  149. return model
  150. if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
  151. nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
  152. state_dict = resrgan2normal(state_dict, nb)
  153. elif "conv_first.weight" in state_dict:
  154. state_dict = mod2normal(state_dict)
  155. elif "model.0.weight" not in state_dict:
  156. raise Exception("The file is not a recognized ESRGAN model.")
  157. in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
  158. model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
  159. model.load_state_dict(state_dict)
  160. model.eval()
  161. return model
  162. def upscale_without_tiling(model, img):
  163. img = np.array(img)
  164. img = img[:, :, ::-1]
  165. img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
  166. img = torch.from_numpy(img).float()
  167. img = img.unsqueeze(0).to(devices.device_esrgan)
  168. with torch.no_grad():
  169. output = model(img)
  170. output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
  171. output = 255. * np.moveaxis(output, 0, 2)
  172. output = output.astype(np.uint8)
  173. output = output[:, :, ::-1]
  174. return Image.fromarray(output, 'RGB')
  175. def esrgan_upscale(model, img):
  176. if opts.ESRGAN_tile == 0:
  177. return upscale_without_tiling(model, img)
  178. grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
  179. newtiles = []
  180. scale_factor = 1
  181. for y, h, row in grid.tiles:
  182. newrow = []
  183. for tiledata in row:
  184. x, w, tile = tiledata
  185. output = upscale_without_tiling(model, tile)
  186. scale_factor = output.width // tile.width
  187. newrow.append([x * scale_factor, w * scale_factor, output])
  188. newtiles.append([y * scale_factor, h * scale_factor, newrow])
  189. newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
  190. output = images.combine_grid(newgrid)
  191. return output