# hat_model.py (1.6 KB)

  1. import os
  2. import sys
  3. from modules import modelloader, devices
  4. from modules.shared import opts
  5. from modules.upscaler import Upscaler, UpscalerData
  6. from modules.upscaler_utils import upscale_with_model
  7. class UpscalerHAT(Upscaler):
  8. def __init__(self, dirname):
  9. self.name = "HAT"
  10. self.scalers = []
  11. self.user_path = dirname
  12. super().__init__()
  13. for file in self.find_models(ext_filter=[".pt", ".pth"]):
  14. name = modelloader.friendly_name(file)
  15. scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
  16. scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
  17. self.scalers.append(scaler_data)
  18. def do_upscale(self, img, selected_model):
  19. try:
  20. model = self.load_model(selected_model)
  21. except Exception as e:
  22. print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
  23. return img
  24. model.to(devices.device_esrgan) # TODO: should probably be device_hat
  25. return upscale_with_model(
  26. model,
  27. img,
  28. tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
  29. tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
  30. )
  31. def load_model(self, path: str):
  32. if not os.path.isfile(path):
  33. raise FileNotFoundError(f"Model file {path} not found")
  34. return modelloader.load_spandrel_model(
  35. path,
  36. device=devices.device_esrgan, # TODO: should probably be device_hat
  37. expected_architecture='HAT',
  38. )