upscaler.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import os
  2. from abc import abstractmethod
  3. import PIL
  4. from PIL import Image
  5. import modules.shared
  6. from modules import modelloader, shared
  7. LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
  8. NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
  9. class Upscaler:
  10. name = None
  11. model_path = None
  12. model_name = None
  13. model_url = None
  14. enable = True
  15. filter = None
  16. model = None
  17. user_path = None
  18. scalers: list
  19. tile = True
  20. def __init__(self, create_dirs=False):
  21. self.mod_pad_h = None
  22. self.tile_size = modules.shared.opts.ESRGAN_tile
  23. self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
  24. self.device = modules.shared.device
  25. self.img = None
  26. self.output = None
  27. self.scale = 1
  28. self.half = not modules.shared.cmd_opts.no_half
  29. self.pre_pad = 0
  30. self.mod_scale = None
  31. self.model_download_path = None
  32. if self.model_path is None and self.name:
  33. self.model_path = os.path.join(shared.models_path, self.name)
  34. if self.model_path and create_dirs:
  35. os.makedirs(self.model_path, exist_ok=True)
  36. try:
  37. import cv2 # noqa: F401
  38. self.can_tile = True
  39. except Exception:
  40. pass
  41. @abstractmethod
  42. def do_upscale(self, img: PIL.Image, selected_model: str):
  43. return img
  44. def upscale(self, img: PIL.Image, scale, selected_model: str = None):
  45. self.scale = scale
  46. dest_w = int((img.width * scale) // 8 * 8)
  47. dest_h = int((img.height * scale) // 8 * 8)
  48. for i in range(3):
  49. if img.width >= dest_w and img.height >= dest_h and (i > 0 or scale != 1):
  50. break
  51. if shared.state.interrupted:
  52. break
  53. shape = (img.width, img.height)
  54. img = self.do_upscale(img, selected_model)
  55. if shape == (img.width, img.height):
  56. break
  57. if img.width != dest_w or img.height != dest_h:
  58. img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
  59. return img
  60. @abstractmethod
  61. def load_model(self, path: str):
  62. pass
  63. def find_models(self, ext_filter=None) -> list:
  64. return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
  65. def update_status(self, prompt):
  66. print(f"\nextras: {prompt}", file=shared.progress_print_out)
  67. class UpscalerData:
  68. name = None
  69. data_path = None
  70. scale: int = 4
  71. scaler: Upscaler = None
  72. model: None
  73. def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None, sha256: str = None):
  74. self.name = name
  75. self.data_path = path
  76. self.local_data_path = path
  77. self.scaler = upscaler
  78. self.scale = scale
  79. self.model = model
  80. self.sha256 = sha256
  81. def __repr__(self):
  82. return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
  83. class UpscalerNone(Upscaler):
  84. name = "None"
  85. scalers = []
  86. def load_model(self, path):
  87. pass
  88. def do_upscale(self, img, selected_model=None):
  89. return img
  90. def __init__(self, dirname=None):
  91. super().__init__(False)
  92. self.scalers = [UpscalerData("None", None, self)]
  93. class UpscalerLanczos(Upscaler):
  94. scalers = []
  95. def do_upscale(self, img, selected_model=None):
  96. return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
  97. def load_model(self, _):
  98. pass
  99. def __init__(self, dirname=None):
  100. super().__init__(False)
  101. self.name = "Lanczos"
  102. self.scalers = [UpscalerData("Lanczos", None, self)]
  103. class UpscalerNearest(Upscaler):
  104. scalers = []
  105. def do_upscale(self, img, selected_model=None):
  106. return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
  107. def load_model(self, _):
  108. pass
  109. def __init__(self, dirname=None):
  110. super().__init__(False)
  111. self.name = "Nearest"
  112. self.scalers = [UpscalerData("Nearest", None, self)]