ui_extra_networks.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. import os.path
  2. import urllib.parse
  3. from pathlib import Path
  4. from modules import shared
  5. from modules.images import read_info_from_image, save_image_with_geninfo
  6. from modules.ui import up_down_symbol
  7. import gradio as gr
  8. import json
  9. import html
  10. from modules.generation_parameters_copypaste import image_from_url_text
  11. extra_pages = []
  12. allowed_dirs = set()
  13. def register_page(page):
  14. """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
  15. extra_pages.append(page)
  16. allowed_dirs.clear()
  17. allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
  18. def fetch_file(filename: str = ""):
  19. from starlette.responses import FileResponse
  20. if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
  21. raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
  22. ext = os.path.splitext(filename)[1].lower()
  23. if ext not in (".png", ".jpg", ".jpeg", ".webp"):
  24. raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
  25. # would profit from returning 304
  26. return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
  27. def get_metadata(page: str = "", item: str = ""):
  28. from starlette.responses import JSONResponse
  29. page = next(iter([x for x in extra_pages if x.name == page]), None)
  30. if page is None:
  31. return JSONResponse({})
  32. metadata = page.metadata.get(item)
  33. if metadata is None:
  34. return JSONResponse({})
  35. return JSONResponse({"metadata": metadata})
  36. def add_pages_to_demo(app):
  37. app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
  38. app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
  39. class ExtraNetworksPage:
  40. def __init__(self, title):
  41. self.title = title
  42. self.name = title.lower()
  43. self.card_page = shared.html("extra-networks-card.html")
  44. self.allow_negative_prompt = False
  45. self.metadata = {}
  46. def refresh(self):
  47. pass
  48. def link_preview(self, filename):
  49. quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
  50. mtime = os.path.getmtime(filename)
  51. return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
  52. def search_terms_from_path(self, filename, possible_directories=None):
  53. abspath = os.path.abspath(filename)
  54. for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
  55. parentdir = os.path.abspath(parentdir)
  56. if abspath.startswith(parentdir):
  57. return abspath[len(parentdir):].replace('\\', '/')
  58. return ""
  59. def create_html(self, tabname):
  60. view = shared.opts.extra_networks_default_view
  61. items_html = ''
  62. self.metadata = {}
  63. subdirs = {}
  64. for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
  65. for root, dirs, _ in os.walk(parentdir, followlinks=True):
  66. for dirname in dirs:
  67. x = os.path.join(root, dirname)
  68. if not os.path.isdir(x):
  69. continue
  70. subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
  71. while subdir.startswith("/"):
  72. subdir = subdir[1:]
  73. is_empty = len(os.listdir(x)) == 0
  74. if not is_empty and not subdir.endswith("/"):
  75. subdir = subdir + "/"
  76. if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
  77. continue
  78. subdirs[subdir] = 1
  79. if subdirs:
  80. subdirs = {"": 1, **subdirs}
  81. subdirs_html = "".join([f"""
  82. <button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
  83. {html.escape(subdir if subdir!="" else "all")}
  84. </button>
  85. """ for subdir in subdirs])
  86. for item in self.list_items():
  87. metadata = item.get("metadata")
  88. if metadata:
  89. self.metadata[item["name"]] = metadata
  90. items_html += self.create_html_for_item(item, tabname)
  91. if items_html == '':
  92. dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
  93. items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
  94. self_name_id = self.name.replace(" ", "_")
  95. res = f"""
  96. <div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
  97. {subdirs_html}
  98. </div>
  99. <div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
  100. {items_html}
  101. </div>
  102. """
  103. return res
  104. def list_items(self):
  105. raise NotImplementedError()
  106. def allowed_directories_for_previews(self):
  107. return []
  108. def create_html_for_item(self, item, tabname):
  109. """
  110. Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
  111. """
  112. preview = item.get("preview", None)
  113. onclick = item.get("onclick", None)
  114. if onclick is None:
  115. onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
  116. height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
  117. width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
  118. background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
  119. metadata_button = ""
  120. metadata = item.get("metadata")
  121. if metadata:
  122. metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
  123. local_path = ""
  124. filename = item.get("filename", "")
  125. for reldir in self.allowed_directories_for_previews():
  126. absdir = os.path.abspath(reldir)
  127. if filename.startswith(absdir):
  128. local_path = filename[len(absdir):]
  129. # if this is true, the item must not be shown in the default view, and must instead only be
  130. # shown when searching for it
  131. if shared.opts.extra_networks_hidden_models == "Always":
  132. search_only = False
  133. else:
  134. search_only = "/." in local_path or "\\." in local_path
  135. if search_only and shared.opts.extra_networks_hidden_models == "Never":
  136. return ""
  137. sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
  138. args = {
  139. "background_image": background_image,
  140. "style": f"'display: none; {height}{width}'",
  141. "prompt": item.get("prompt", None),
  142. "tabname": json.dumps(tabname),
  143. "local_preview": json.dumps(item["local_preview"]),
  144. "name": item["name"],
  145. "description": (item.get("description") or ""),
  146. "card_clicked": onclick,
  147. "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
  148. "search_term": item.get("search_term", ""),
  149. "metadata_button": metadata_button,
  150. "search_only": " search_only" if search_only else "",
  151. "sort_keys": sort_keys,
  152. }
  153. return self.card_page.format(**args)
  154. def get_sort_keys(self, path):
  155. """
  156. List of default keys used for sorting in the UI.
  157. """
  158. pth = Path(path)
  159. stat = pth.stat()
  160. return {
  161. "date_created": int(stat.st_ctime or 0),
  162. "date_modified": int(stat.st_mtime or 0),
  163. "name": pth.name.lower(),
  164. }
  165. def find_preview(self, path):
  166. """
  167. Find a preview PNG for a given path (without extension) and call link_preview on it.
  168. """
  169. preview_extensions = ["png", "jpg", "jpeg", "webp"]
  170. if shared.opts.samples_format not in preview_extensions:
  171. preview_extensions.append(shared.opts.samples_format)
  172. potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
  173. for file in potential_files:
  174. if os.path.isfile(file):
  175. return self.link_preview(file)
  176. return None
  177. def find_description(self, path):
  178. """
  179. Find and read a description file for a given path (without extension).
  180. """
  181. for file in [f"{path}.txt", f"{path}.description.txt"]:
  182. try:
  183. with open(file, "r", encoding="utf-8", errors="replace") as f:
  184. return f.read()
  185. except OSError:
  186. pass
  187. return None
  188. def initialize():
  189. extra_pages.clear()
  190. def register_default_pages():
  191. from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
  192. from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
  193. from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
  194. register_page(ExtraNetworksPageTextualInversion())
  195. register_page(ExtraNetworksPageHypernetworks())
  196. register_page(ExtraNetworksPageCheckpoints())
  197. class ExtraNetworksUi:
  198. def __init__(self):
  199. self.pages = None
  200. """gradio HTML components related to extra networks' pages"""
  201. self.page_contents = None
  202. """HTML content of the above; empty initially, filled when extra pages have to be shown"""
  203. self.stored_extra_pages = None
  204. self.button_save_preview = None
  205. self.preview_target_filename = None
  206. self.tabname = None
  207. def pages_in_preferred_order(pages):
  208. tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
  209. def tab_name_score(name):
  210. name = name.lower()
  211. for i, possible_match in enumerate(tab_order):
  212. if possible_match in name:
  213. return i
  214. return len(pages)
  215. tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
  216. return sorted(pages, key=lambda x: tab_scores[x.name])
  217. def create_ui(container, button, tabname):
  218. ui = ExtraNetworksUi()
  219. ui.pages = []
  220. ui.pages_contents = []
  221. ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
  222. ui.tabname = tabname
  223. with gr.Tabs(elem_id=tabname+"_extra_tabs"):
  224. for page in ui.stored_extra_pages:
  225. page_id = page.title.lower().replace(" ", "_")
  226. with gr.Tab(page.title, id=page_id):
  227. elem_id = f"{tabname}_{page_id}_cards_html"
  228. page_elem = gr.HTML('Loading...', elem_id=elem_id)
  229. ui.pages.append(page_elem)
  230. page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
  231. gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
  232. gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
  233. gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
  234. button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
  235. ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
  236. ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
  237. def toggle_visibility(is_visible):
  238. is_visible = not is_visible
  239. return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
  240. def fill_tabs(is_empty):
  241. """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
  242. if not ui.pages_contents:
  243. refresh()
  244. if is_empty:
  245. return True, *ui.pages_contents
  246. return True, *[gr.update() for _ in ui.pages_contents]
  247. state_visible = gr.State(value=False)
  248. button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
  249. state_empty = gr.State(value=True)
  250. button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
  251. def refresh():
  252. for pg in ui.stored_extra_pages:
  253. pg.refresh()
  254. ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
  255. return ui.pages_contents
  256. button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
  257. return ui
  258. def path_is_parent(parent_path, child_path):
  259. parent_path = os.path.abspath(parent_path)
  260. child_path = os.path.abspath(child_path)
  261. return child_path.startswith(parent_path)
  262. def setup_ui(ui, gallery):
  263. def save_preview(index, images, filename):
  264. if len(images) == 0:
  265. print("There is no image in gallery to save as a preview.")
  266. return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
  267. index = int(index)
  268. index = 0 if index < 0 else index
  269. index = len(images) - 1 if index >= len(images) else index
  270. img_info = images[index if index >= 0 else 0]
  271. image = image_from_url_text(img_info)
  272. geninfo, items = read_info_from_image(image)
  273. is_allowed = False
  274. for extra_page in ui.stored_extra_pages:
  275. if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
  276. is_allowed = True
  277. break
  278. assert is_allowed, f'writing to {filename} is not allowed'
  279. save_image_with_geninfo(image, geninfo, filename)
  280. return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
  281. ui.button_save_preview.click(
  282. fn=save_preview,
  283. _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
  284. inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
  285. outputs=[*ui.pages]
  286. )