esrgan_model.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from modules import modelloader, devices, errors
  2. from modules.shared import opts
  3. from modules.upscaler import Upscaler, UpscalerData
  4. from modules.upscaler_utils import upscale_with_model
  5. class UpscalerESRGAN(Upscaler):
  6. def __init__(self, dirname):
  7. self.name = "ESRGAN"
  8. self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
  9. self.model_name = "ESRGAN_4x"
  10. self.scalers = []
  11. self.user_path = dirname
  12. super().__init__()
  13. model_paths = self.find_models(ext_filter=[".pt", ".pth"])
  14. scalers = []
  15. if len(model_paths) == 0:
  16. scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
  17. scalers.append(scaler_data)
  18. for file in model_paths:
  19. if file.startswith("http"):
  20. name = self.model_name
  21. else:
  22. name = modelloader.friendly_name(file)
  23. scaler_data = UpscalerData(name, file, self, 4)
  24. self.scalers.append(scaler_data)
  25. def do_upscale(self, img, selected_model):
  26. try:
  27. model = self.load_model(selected_model)
  28. except Exception:
  29. errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
  30. return img
  31. model.to(devices.device_esrgan)
  32. return esrgan_upscale(model, img)
  33. def load_model(self, path: str):
  34. if path.startswith("http"):
  35. # TODO: this doesn't use `path` at all?
  36. filename = modelloader.load_file_from_url(
  37. url=self.model_url,
  38. model_dir=self.model_download_path,
  39. file_name=f"{self.model_name}.pth",
  40. )
  41. else:
  42. filename = path
  43. return modelloader.load_spandrel_model(
  44. filename,
  45. device=('cpu' if devices.device_esrgan.type == 'mps' else None),
  46. expected_architecture='ESRGAN',
  47. )
  48. def esrgan_upscale(model, img):
  49. return upscale_with_model(
  50. model,
  51. img,
  52. tile_size=opts.ESRGAN_tile,
  53. tile_overlap=opts.ESRGAN_tile_overlap,
  54. )