modelloader.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import os
  2. import shutil
  3. import importlib
  4. from urllib.parse import urlparse
  5. from modules import shared
  6. from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
  7. from modules.paths import script_path, models_path
  8. def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
  9. """
  10. A one-and done loader to try finding the desired models in specified directories.
  11. @param download_name: Specify to download from model_url immediately.
  12. @param model_url: If no other models are found, this will be downloaded on upscale.
  13. @param model_path: The location to store/find models in.
  14. @param command_path: A command-line argument to search for models in first.
  15. @param ext_filter: An optional list of filename extensions to filter by
  16. @return: A list of paths containing the desired model(s)
  17. """
  18. output = []
  19. try:
  20. places = []
  21. if command_path is not None and command_path != model_path:
  22. pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
  23. if os.path.exists(pretrained_path):
  24. print(f"Appending path: {pretrained_path}")
  25. places.append(pretrained_path)
  26. elif os.path.exists(command_path):
  27. places.append(command_path)
  28. places.append(model_path)
  29. for place in places:
  30. for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
  31. if os.path.islink(full_path) and not os.path.exists(full_path):
  32. print(f"Skipping broken symlink: {full_path}")
  33. continue
  34. if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
  35. continue
  36. if full_path not in output:
  37. output.append(full_path)
  38. if model_url is not None and len(output) == 0:
  39. if download_name is not None:
  40. from basicsr.utils.download_util import load_file_from_url
  41. dl = load_file_from_url(model_url, places[0], True, download_name)
  42. output.append(dl)
  43. else:
  44. output.append(model_url)
  45. except Exception:
  46. pass
  47. return output
  48. def friendly_name(file: str):
  49. if "http" in file:
  50. file = urlparse(file).path
  51. file = os.path.basename(file)
  52. model_name, extension = os.path.splitext(file)
  53. return model_name
  54. def cleanup_models():
  55. # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
  56. # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
  57. # somehow auto-register and just do these things...
  58. root_path = script_path
  59. src_path = models_path
  60. dest_path = os.path.join(models_path, "Stable-diffusion")
  61. move_files(src_path, dest_path, ".ckpt")
  62. move_files(src_path, dest_path, ".safetensors")
  63. src_path = os.path.join(root_path, "ESRGAN")
  64. dest_path = os.path.join(models_path, "ESRGAN")
  65. move_files(src_path, dest_path)
  66. src_path = os.path.join(models_path, "BSRGAN")
  67. dest_path = os.path.join(models_path, "ESRGAN")
  68. move_files(src_path, dest_path, ".pth")
  69. src_path = os.path.join(root_path, "gfpgan")
  70. dest_path = os.path.join(models_path, "GFPGAN")
  71. move_files(src_path, dest_path)
  72. src_path = os.path.join(root_path, "SwinIR")
  73. dest_path = os.path.join(models_path, "SwinIR")
  74. move_files(src_path, dest_path)
  75. src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
  76. dest_path = os.path.join(models_path, "LDSR")
  77. move_files(src_path, dest_path)
  78. def move_files(src_path: str, dest_path: str, ext_filter: str = None):
  79. try:
  80. if not os.path.exists(dest_path):
  81. os.makedirs(dest_path)
  82. if os.path.exists(src_path):
  83. for file in os.listdir(src_path):
  84. fullpath = os.path.join(src_path, file)
  85. if os.path.isfile(fullpath):
  86. if ext_filter is not None:
  87. if ext_filter not in file:
  88. continue
  89. print(f"Moving {file} from {src_path} to {dest_path}.")
  90. try:
  91. shutil.move(fullpath, dest_path)
  92. except Exception:
  93. pass
  94. if len(os.listdir(src_path)) == 0:
  95. print(f"Removing empty folder: {src_path}")
  96. shutil.rmtree(src_path, True)
  97. except Exception:
  98. pass
  99. def load_upscalers():
  100. # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
  101. # so we'll try to import any _model.py files before looking in __subclasses__
  102. modules_dir = os.path.join(shared.script_path, "modules")
  103. for file in os.listdir(modules_dir):
  104. if "_model.py" in file:
  105. model_name = file.replace("_model.py", "")
  106. full_model = f"modules.{model_name}_model"
  107. try:
  108. importlib.import_module(full_model)
  109. except Exception:
  110. pass
  111. datas = []
  112. commandline_options = vars(shared.cmd_opts)
  113. # some of upscaler classes will not go away after reloading their modules, and we'll end
  114. # up with two copies of those classes. The newest copy will always be the last in the list,
  115. # so we go from end to beginning and ignore duplicates
  116. used_classes = {}
  117. for cls in reversed(Upscaler.__subclasses__()):
  118. classname = str(cls)
  119. if classname not in used_classes:
  120. used_classes[classname] = cls
  121. for cls in reversed(used_classes.values()):
  122. name = cls.__name__
  123. cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
  124. commandline_model_path = commandline_options.get(cmd_name, None)
  125. scaler = cls(commandline_model_path)
  126. scaler.user_path = commandline_model_path
  127. scaler.model_download_path = commandline_model_path or scaler.model_path
  128. datas += scaler.scalers
  129. shared.sd_upscalers = sorted(
  130. datas,
  131. # Special case for UpscalerNone keeps it at the beginning of the list.
  132. key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
  133. )