# realesrgan_model.py
  1. import os
  2. from modules import modelloader, errors
  3. from modules.shared import cmd_opts, opts
  4. from modules.upscaler import Upscaler, UpscalerData
  5. from modules.upscaler_utils import upscale_with_model
  6. class UpscalerRealESRGAN(Upscaler):
  7. def __init__(self, path):
  8. self.name = "RealESRGAN"
  9. self.user_path = path
  10. super().__init__()
  11. self.enable = True
  12. self.scalers = []
  13. scalers = get_realesrgan_models(self)
  14. local_model_paths = self.find_models(ext_filter=[".pth"])
  15. for scaler in scalers:
  16. if scaler.local_data_path.startswith("http"):
  17. filename = modelloader.friendly_name(scaler.local_data_path)
  18. local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
  19. if local_model_candidates:
  20. scaler.local_data_path = local_model_candidates[0]
  21. if scaler.name in opts.realesrgan_enabled_models:
  22. self.scalers.append(scaler)
  23. def do_upscale(self, img, path):
  24. if not self.enable:
  25. return img
  26. try:
  27. info = self.load_model(path)
  28. except Exception:
  29. errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
  30. return img
  31. model_descriptor = modelloader.load_spandrel_model(
  32. info.local_data_path,
  33. device=self.device,
  34. prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
  35. expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
  36. )
  37. return upscale_with_model(
  38. model_descriptor,
  39. img,
  40. tile_size=opts.ESRGAN_tile,
  41. tile_overlap=opts.ESRGAN_tile_overlap,
  42. # TODO: `outscale`?
  43. )
  44. def load_model(self, path):
  45. for scaler in self.scalers:
  46. if scaler.data_path == path:
  47. if scaler.local_data_path.startswith("http"):
  48. scaler.local_data_path = modelloader.load_file_from_url(
  49. scaler.data_path,
  50. model_dir=self.model_download_path,
  51. )
  52. if not os.path.exists(scaler.local_data_path):
  53. raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
  54. return scaler
  55. raise ValueError(f"Unable to find model info: {path}")
  56. def get_realesrgan_models(scaler: UpscalerRealESRGAN):
  57. return [
  58. UpscalerData(
  59. name="R-ESRGAN General 4xV3",
  60. path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
  61. scale=4,
  62. upscaler=scaler,
  63. ),
  64. UpscalerData(
  65. name="R-ESRGAN General WDN 4xV3",
  66. path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
  67. scale=4,
  68. upscaler=scaler,
  69. ),
  70. UpscalerData(
  71. name="R-ESRGAN AnimeVideo",
  72. path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
  73. scale=4,
  74. upscaler=scaler,
  75. ),
  76. UpscalerData(
  77. name="R-ESRGAN 4x+",
  78. path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
  79. scale=4,
  80. upscaler=scaler,
  81. ),
  82. UpscalerData(
  83. name="R-ESRGAN 4x+ Anime6B",
  84. path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
  85. scale=4,
  86. upscaler=scaler,
  87. ),
  88. UpscalerData(
  89. name="R-ESRGAN 2x+",
  90. path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
  91. scale=2,
  92. upscaler=scaler,
  93. ),
  94. ]