"""Miscellaneous helper utilities: natural sorting, file listing and caching,
path handling, topological sorting, and model-file downloads."""
  1. import os
  2. import re
  3. from modules import shared
  4. from modules.paths_internal import script_path, cwd
  5. def natural_sort_key(s, regex=re.compile('([0-9]+)')):
  6. return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
  7. def listfiles(dirname):
  8. filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
  9. return [file for file in filenames if os.path.isfile(file)]
  10. def html_path(filename):
  11. return os.path.join(script_path, "html", filename)
  12. def html(filename):
  13. path = html_path(filename)
  14. try:
  15. with open(path, encoding="utf8") as file:
  16. return file.read()
  17. except OSError:
  18. return ""
  19. def walk_files(path, allowed_extensions=None):
  20. if not os.path.exists(path):
  21. return
  22. if allowed_extensions is not None:
  23. allowed_extensions = set(allowed_extensions)
  24. items = list(os.walk(path, followlinks=True))
  25. items = sorted(items, key=lambda x: natural_sort_key(x[0]))
  26. for root, _, files in items:
  27. for filename in sorted(files, key=natural_sort_key):
  28. if allowed_extensions is not None:
  29. _, ext = os.path.splitext(filename)
  30. if ext.lower() not in allowed_extensions:
  31. continue
  32. if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
  33. continue
  34. yield os.path.join(root, filename)
  35. def ldm_print(*args, **kwargs):
  36. if shared.opts.hide_ldm_prints:
  37. return
  38. print(*args, **kwargs)
  39. def truncate_path(target_path, base_path=cwd):
  40. abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
  41. try:
  42. if os.path.commonpath([abs_target, abs_base]) == abs_base:
  43. return os.path.relpath(abs_target, abs_base)
  44. except ValueError:
  45. pass
  46. return abs_target
  47. class MassFileListerCachedDir:
  48. """A class that caches file metadata for a specific directory."""
  49. def __init__(self, dirname):
  50. self.files = None
  51. self.files_cased = None
  52. self.dirname = dirname
  53. stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname))
  54. files = [(n, s.st_mtime, s.st_ctime) for n, s in stats]
  55. self.files = {x[0].lower(): x for x in files}
  56. self.files_cased = {x[0]: x for x in files}
  57. def update_entry(self, filename):
  58. """Add a file to the cache"""
  59. file_path = os.path.join(self.dirname, filename)
  60. try:
  61. stat = os.stat(file_path)
  62. entry = (filename, stat.st_mtime, stat.st_ctime)
  63. self.files[filename.lower()] = entry
  64. self.files_cased[filename] = entry
  65. except FileNotFoundError as e:
  66. print(f'MassFileListerCachedDir.add_entry: "{file_path}" {e}')
  67. class MassFileLister:
  68. """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""
  69. def __init__(self):
  70. self.cached_dirs = {}
  71. def find(self, path):
  72. """
  73. Find the metadata for a file at the given path.
  74. Returns:
  75. tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not.
  76. """
  77. dirname, filename = os.path.split(path)
  78. cached_dir = self.cached_dirs.get(dirname)
  79. if cached_dir is None:
  80. cached_dir = MassFileListerCachedDir(dirname)
  81. self.cached_dirs[dirname] = cached_dir
  82. stats = cached_dir.files_cased.get(filename)
  83. if stats is not None:
  84. return stats
  85. stats = cached_dir.files.get(filename.lower())
  86. if stats is None:
  87. return None
  88. try:
  89. os_stats = os.stat(path, follow_symlinks=False)
  90. return filename, os_stats.st_mtime, os_stats.st_ctime
  91. except Exception:
  92. return None
  93. def exists(self, path):
  94. """Check if a file exists at the given path."""
  95. return self.find(path) is not None
  96. def mctime(self, path):
  97. """
  98. Get the modification and creation times for a file at the given path.
  99. Returns:
  100. tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not.
  101. """
  102. stats = self.find(path)
  103. return (0, 0) if stats is None else stats[1:3]
  104. def reset(self):
  105. """Clear the cache of all directories."""
  106. self.cached_dirs.clear()
  107. def update_file_entry(self, path):
  108. """Update the cache for a specific directory."""
  109. dirname, filename = os.path.split(path)
  110. if cached_dir := self.cached_dirs.get(dirname):
  111. cached_dir.update_entry(filename)
  112. def topological_sort(dependencies):
  113. """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
  114. Ignores errors relating to missing dependencies or circular dependencies
  115. """
  116. visited = {}
  117. result = []
  118. def inner(name):
  119. visited[name] = True
  120. for dep in dependencies.get(name, []):
  121. if dep in dependencies and dep not in visited:
  122. inner(dep)
  123. result.append(name)
  124. for depname in dependencies:
  125. if depname not in visited:
  126. inner(depname)
  127. return result
  128. def open_folder(path):
  129. """Open a folder in the file manager of the respect OS."""
  130. # import at function level to avoid potential issues
  131. import gradio as gr
  132. import platform
  133. import sys
  134. import subprocess
  135. if not os.path.exists(path):
  136. msg = f'Folder "{path}" does not exist. after you save an image, the folder will be created.'
  137. print(msg)
  138. gr.Info(msg)
  139. return
  140. elif not os.path.isdir(path):
  141. msg = f"""
  142. WARNING
  143. An open_folder request was made with an path that is not a folder.
  144. This could be an error or a malicious attempt to run code on your computer.
  145. Requested path was: {path}
  146. """
  147. print(msg, file=sys.stderr)
  148. gr.Warning(msg)
  149. return
  150. path = os.path.normpath(path)
  151. if platform.system() == "Windows":
  152. os.startfile(path)
  153. elif platform.system() == "Darwin":
  154. subprocess.Popen(["open", path])
  155. elif "microsoft-standard-WSL2" in platform.uname().release:
  156. subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])])
  157. else:
  158. subprocess.Popen(["xdg-open", path])
  159. def load_file_from_url(
  160. url: str,
  161. *,
  162. model_dir: str,
  163. progress: bool = True,
  164. file_name: str | None = None,
  165. hash_prefix: str | None = None,
  166. re_download: bool = False,
  167. ) -> str:
  168. """Download a file from `url` into `model_dir`, using the file present if possible.
  169. Returns the path to the downloaded file.
  170. file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
  171. file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
  172. hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
  173. if the hash does not match, the temporary file is deleted and a ValueError is raised.
  174. re_download: forcibly re-download the file even if it already exists.
  175. """
  176. from urllib.parse import urlparse
  177. import requests
  178. try:
  179. from tqdm import tqdm
  180. except ImportError:
  181. class tqdm:
  182. def __init__(self, *args, **kwargs):
  183. pass
  184. def update(self, n=1, *args, **kwargs):
  185. pass
  186. def __enter__(self):
  187. return self
  188. def __exit__(self, exc_type, exc_val, exc_tb):
  189. pass
  190. if not file_name:
  191. parts = urlparse(url)
  192. file_name = os.path.basename(parts.path)
  193. cached_file = os.path.abspath(os.path.join(model_dir, file_name))
  194. if re_download or not os.path.exists(cached_file):
  195. os.makedirs(model_dir, exist_ok=True)
  196. temp_file = os.path.join(model_dir, f"{file_name}.tmp")
  197. print(f'\nDownloading: "{url}" to {cached_file}')
  198. response = requests.get(url, stream=True)
  199. response.raise_for_status()
  200. total_size = int(response.headers.get('content-length', 0))
  201. with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar:
  202. with open(temp_file, 'wb') as file:
  203. for chunk in response.iter_content(chunk_size=1024):
  204. if chunk:
  205. file.write(chunk)
  206. progress_bar.update(len(chunk))
  207. if hash_prefix and not compare_sha256(temp_file, hash_prefix):
  208. print(f"Hash mismatch for {temp_file}. Deleting the temporary file.")
  209. os.remove(temp_file)
  210. raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!")
  211. os.rename(temp_file, cached_file)
  212. return cached_file
  213. def compare_sha256(file_path: str, hash_prefix: str) -> bool:
  214. """Check if the SHA256 hash of the file matches the given prefix."""
  215. import hashlib
  216. hash_sha256 = hashlib.sha256()
  217. blksize = 1024 * 1024
  218. with open(file_path, "rb") as f:
  219. for chunk in iter(lambda: f.read(blksize), b""):
  220. hash_sha256.update(chunk)
  221. return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
  222. class GenerationParametersList(list):
  223. """A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params
  224. due to StableDiffusionProcessing.get_conds_with_caching
  225. extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used
  226. When an extra_generation_params is set in StableDiffusionModelHijack using this object,
  227. the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states
  228. the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching
  229. and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states
  230. Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis'
  231. Depending on the use case the methods can be overwritten.
  232. In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext.
  233. When called by create_infotext it will access to the locals() of the caller,
  234. if return str, the value will be written to infotext, if return None will be ignored.
  235. """
  236. def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
  237. super().__init__(*args, **kwargs)
  238. self._to_be_clear_before_batch = to_be_clear_before_batch
  239. def __call__(self, *args, **kwargs):
  240. return ', '.join(sorted(set(self), key=natural_sort_key))
  241. @property
  242. def to_be_clear_before_batch(self):
  243. return self._to_be_clear_before_batch
  244. def __add__(self, other):
  245. if isinstance(other, GenerationParametersList):
  246. return self.__class__([*self, *other])
  247. elif isinstance(other, str):
  248. return self.__str__() + other
  249. return NotImplemented
  250. def __radd__(self, other):
  251. if isinstance(other, str):
  252. return other + self.__str__()
  253. return NotImplemented
  254. def __str__(self):
  255. return self.__call__()