# dat_model.py (2.6 KB)
  1. import os
  2. from modules import modelloader, errors
  3. from modules.shared import cmd_opts, opts
  4. from modules.upscaler import Upscaler, UpscalerData
  5. from modules.upscaler_utils import upscale_with_model
  6. class UpscalerDAT(Upscaler):
  7. def __init__(self, user_path):
  8. self.name = "DAT"
  9. self.user_path = user_path
  10. self.scalers = []
  11. super().__init__()
  12. for file in self.find_models(ext_filter=[".pt", ".pth"]):
  13. name = modelloader.friendly_name(file)
  14. scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
  15. self.scalers.append(scaler_data)
  16. for model in get_dat_models(self):
  17. if model.name in opts.dat_enabled_models:
  18. self.scalers.append(model)
  19. def do_upscale(self, img, path):
  20. try:
  21. info = self.load_model(path)
  22. except Exception:
  23. errors.report(f"Unable to load DAT model {path}", exc_info=True)
  24. return img
  25. model_descriptor = modelloader.load_spandrel_model(
  26. info.local_data_path,
  27. device=self.device,
  28. prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
  29. expected_architecture="DAT",
  30. )
  31. return upscale_with_model(
  32. model_descriptor,
  33. img,
  34. tile_size=opts.DAT_tile,
  35. tile_overlap=opts.DAT_tile_overlap,
  36. )
  37. def load_model(self, path):
  38. for scaler in self.scalers:
  39. if scaler.data_path == path:
  40. if scaler.local_data_path.startswith("http"):
  41. scaler.local_data_path = modelloader.load_file_from_url(
  42. scaler.data_path,
  43. model_dir=self.model_download_path,
  44. )
  45. if not os.path.exists(scaler.local_data_path):
  46. raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
  47. return scaler
  48. raise ValueError(f"Unable to find model info: {path}")
  49. def get_dat_models(scaler):
  50. return [
  51. UpscalerData(
  52. name="DAT x2",
  53. path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
  54. scale=2,
  55. upscaler=scaler,
  56. ),
  57. UpscalerData(
  58. name="DAT x3",
  59. path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
  60. scale=3,
  61. upscaler=scaler,
  62. ),
  63. UpscalerData(
  64. name="DAT x4",
  65. path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
  66. scale=4,
  67. upscaler=scaler,
  68. ),
  69. ]