modelloader.py

from __future__ import annotations

import logging
import os
import shutil
import importlib
from urllib.parse import urlparse

import torch

from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path

logger = logging.getLogger(__name__)


def load_file_from_url(
    url: str,
    *,
    model_dir: str,
    progress: bool = True,
    file_name: str | None = None,
) -> str:
    """Download a file from `url` into `model_dir`, using the file present if possible.

    Returns the path to the downloaded file.
    """
    os.makedirs(model_dir, exist_ok=True)
    if not file_name:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        from torch.hub import download_url_to_file
        download_url_to_file(url, cached_file, progress=progress)
    return cached_file
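
# Illustrative usage (the URL and target subdirectory below are hypothetical examples,
# not defaults of this module):
#   weights_path = load_file_from_url(
#       "https://example.com/weights/RealESRGAN_x4plus.pth",
#       model_dir=os.path.join(models_path, "ESRGAN"),
#   )
#   # weights_path is the absolute path of the cached file; nothing is re-downloaded
#   # if the file already exists in model_dir.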


def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
    """
    A one-and-done loader that tries to find the desired models in the specified directories.

    @param download_name: Specify to download from model_url immediately.
    @param model_url: If no other models are found, this will be downloaded on upscale.
    @param model_path: The location to store/find models in.
    @param command_path: A command-line argument to search for models in first.
    @param ext_filter: An optional list of filename extensions to filter by.
    @return: A list of paths containing the desired model(s).
    """
    output = []

    try:
        places = []

        if command_path is not None and command_path != model_path:
            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
            if os.path.exists(pretrained_path):
                print(f"Appending path: {pretrained_path}")
                places.append(pretrained_path)
            elif os.path.exists(command_path):
                places.append(command_path)

        places.append(model_path)

        for place in places:
            for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
                if os.path.islink(full_path) and not os.path.exists(full_path):
                    print(f"Skipping broken symlink: {full_path}")
                    continue
                if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
                    continue
                if full_path not in output:
                    output.append(full_path)

        if model_url is not None and len(output) == 0:
            if download_name is not None:
                output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
            else:
                output.append(model_url)
    except Exception:
        pass

    return output
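
# Illustrative usage (the paths, URL and file name below are hypothetical examples):
#   gfpgan_models = load_models(
#       model_path=os.path.join(models_path, "GFPGAN"),
#       model_url="https://example.com/weights/GFPGANv1.4.pth",
#       command_path=None,
#       ext_filter=[".pth"],
#       download_name="GFPGANv1.4.pth",
#   )
#   # Local matches are returned first; if none are found, the URL is downloaded
#   # immediately (because download_name is given) or returned as-is otherwise.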


def friendly_name(file: str):
    if file.startswith("http"):
        file = urlparse(file).path

    file = os.path.basename(file)
    model_name, extension = os.path.splitext(file)
    return model_name
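
# Example (hypothetical inputs): both friendly_name("https://example.com/weights/4x-UltraSharp.pth")
# and friendly_name("/models/ESRGAN/4x-UltraSharp.pth") return "4x-UltraSharp".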


def cleanup_models():
    # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
    # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
    # somehow auto-register and just do these things...
    root_path = script_path
    src_path = models_path
    dest_path = os.path.join(models_path, "Stable-diffusion")
    move_files(src_path, dest_path, ".ckpt")
    move_files(src_path, dest_path, ".safetensors")
    src_path = os.path.join(root_path, "ESRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    move_files(src_path, dest_path)
    src_path = os.path.join(models_path, "BSRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    move_files(src_path, dest_path, ".pth")
    src_path = os.path.join(root_path, "gfpgan")
    dest_path = os.path.join(models_path, "GFPGAN")
    move_files(src_path, dest_path)
    src_path = os.path.join(root_path, "SwinIR")
    dest_path = os.path.join(models_path, "SwinIR")
    move_files(src_path, dest_path)
    src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
    dest_path = os.path.join(models_path, "LDSR")
    move_files(src_path, dest_path)


def move_files(src_path: str, dest_path: str, ext_filter: str = None):
    try:
        os.makedirs(dest_path, exist_ok=True)
        if os.path.exists(src_path):
            for file in os.listdir(src_path):
                fullpath = os.path.join(src_path, file)
                if os.path.isfile(fullpath):
                    if ext_filter is not None:
                        if ext_filter not in file:
                            continue
                    print(f"Moving {file} from {src_path} to {dest_path}.")
                    try:
                        shutil.move(fullpath, dest_path)
                    except Exception:
                        pass
            if len(os.listdir(src_path)) == 0:
                print(f"Removing empty folder: {src_path}")
                shutil.rmtree(src_path, True)
    except Exception:
        pass


def load_upscalers():
    # We can only do this 'magic' method of dynamically loading upscalers if they are referenced,
    # so we'll try to import any *_model.py files before looking in __subclasses__
    modules_dir = os.path.join(shared.script_path, "modules")
    for file in os.listdir(modules_dir):
        if "_model.py" in file:
            model_name = file.replace("_model.py", "")
            full_model = f"modules.{model_name}_model"
            try:
                importlib.import_module(full_model)
            except Exception:
                pass

    datas = []
    commandline_options = vars(shared.cmd_opts)

    # Some of the upscaler classes will not go away after reloading their modules, and we'll end
    # up with two copies of those classes. The newest copy will always be the last in the list,
    # so we go from end to beginning and ignore duplicates.
    used_classes = {}
    for cls in reversed(Upscaler.__subclasses__()):
        classname = str(cls)
        if classname not in used_classes:
            used_classes[classname] = cls

    for cls in reversed(used_classes.values()):
        name = cls.__name__
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        commandline_model_path = commandline_options.get(cmd_name, None)
        scaler = cls(commandline_model_path)
        scaler.user_path = commandline_model_path
        scaler.model_download_path = commandline_model_path or scaler.model_path
        datas += scaler.scalers

    shared.sd_upscalers = sorted(
        datas,
        # Special case for UpscalerNone keeps it at the beginning of the list.
        key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
    )
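
# Usage sketch: call load_upscalers() once (for example at startup or after modules are reloaded),
# then read the populated list from shared.sd_upscalers:
#   load_upscalers()
#   upscaler_names = [data.name for data in shared.sd_upscalers]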


def load_spandrel_model(
    path: str,
    *,
    device: str | torch.device | None,
    half: bool = False,
    dtype: str | None = None,
    expected_architecture: str | None = None,
):
    import spandrel

    model = spandrel.ModelLoader(device=device).load_from_file(path)
    if expected_architecture and model.architecture != expected_architecture:
        raise TypeError(f"Model {path} is not a {expected_architecture} model")
    if half:
        model = model.model.half()
    if dtype:
        model = model.model.to(dtype=dtype)
    model.eval()
    logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
    return model
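
# Illustrative usage (the file path and architecture name are hypothetical; spandrel must be installed):
#   sr_model = load_spandrel_model(
#       os.path.join(models_path, "ESRGAN", "4x-UltraSharp.pth"),
#       device="cuda" if torch.cuda.is_available() else "cpu",
#       expected_architecture="ESRGAN",
#   )
#   # With half/dtype left unset, the spandrel ModelDescriptor is returned; passing half=True
#   # or a dtype returns the underlying torch module instead (note the reassignments above).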