Răsfoiți Sursa

add checkpoints tab for extra networks UI

AUTOMATIC 2 ani în urmă
părinte
comite
1d8e06d542

+ 1 - 1
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -20,7 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
             preview = None
             for file in previews:
                 if os.path.isfile(file):
-                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+                    preview = self.link_preview(file)
                     break
 
             yield {

+ 7 - 0
javascript/ui.js

@@ -309,3 +309,10 @@ function updateInput(target){
 	Object.defineProperty(e, "target", {value: target})
 	target.dispatchEvent(e);
 }
+
+
+var desiredCheckpointName = null;
+function selectCheckpoint(name){
+    desiredCheckpointName = name;
+    gradioApp().getElementById('change_checkpoint').click()
+}

+ 8 - 0
modules/ui.py

@@ -1560,6 +1560,14 @@ def create_ui():
                 outputs=[component, text_settings],
             )
 
+        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+        button_set_checkpoint.click(
+            fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
+            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+            inputs=[component_dict['sd_model_checkpoint'], dummy_component],
+            outputs=[component_dict['sd_model_checkpoint'], text_settings],
+        )
+
         component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
 
         def get_settings_values():

+ 33 - 4
modules/ui_extra_networks.py

@@ -1,4 +1,6 @@
 import os.path
+import urllib.parse
+from pathlib import Path
 
 from modules import shared
 import gradio as gr
@@ -8,12 +10,31 @@ import html
 from modules.generation_parameters_copypaste import image_from_url_text
 
 extra_pages = []
+allowed_dirs = set()
 
 
 def register_page(page):
     """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
 
     extra_pages.append(page)
+    allowed_dirs.clear()
+    allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
+
+
+def add_pages_to_demo(app):
+    def fetch_file(filename: str = ""):
+        from starlette.responses import FileResponse
+
+        if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
+            raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+
+        if os.path.splitext(filename)[1].lower() != ".png":
+            raise ValueError(f"File cannot be fetched: {filename}. Only png.")
+
+        # would profit from returning 304
+        return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+
+    app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
 
 
 class ExtraNetworksPage:
@@ -26,6 +47,9 @@ class ExtraNetworksPage:
     def refresh(self):
         pass
 
+    def link_preview(self, filename):
+        return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
+
     def create_html(self, tabname):
         view = shared.opts.extra_networks_default_view
         items_html = ''
@@ -54,13 +78,17 @@ class ExtraNetworksPage:
     def create_html_for_item(self, item, tabname):
         preview = item.get("preview", None)
 
+        onclick = item.get("onclick", None)
+        if onclick is None:
+            onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+
         args = {
             "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
-            "prompt": item["prompt"],
+            "prompt": item.get("prompt", None),
             "tabname": json.dumps(tabname),
             "local_preview": json.dumps(item["local_preview"]),
             "name": item["name"],
-            "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
+            "card_clicked": onclick,
             "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
         }
 
@@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path):
     parent_path = os.path.abspath(parent_path)
     child_path = os.path.abspath(child_path)
 
-    return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+    return child_path.startswith(parent_path)
 
 
 def setup_ui(ui, gallery):
@@ -173,7 +201,8 @@ def setup_ui(ui, gallery):
 
     ui.button_save_preview.click(
         fn=save_preview,
-        _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+        _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
         inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
         outputs=[*ui.pages]
     )
+

+ 38 - 0
modules/ui_extra_networks_checkpoints.py

@@ -0,0 +1,38 @@
+import html
+import json
+import os
+import urllib.parse
+
+from modules import shared, ui_extra_networks, sd_models
+
+
+class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
+    def __init__(self):
+        super().__init__('Checkpoints')
+
+    def refresh(self):
+        shared.refresh_checkpoints()
+
+    def list_items(self):
+        for name, checkpoint1 in sd_models.checkpoints_list.items():
+            checkpoint: sd_models.CheckpointInfo = checkpoint1
+            path, ext = os.path.splitext(checkpoint.filename)
+            previews = [path + ".png", path + ".preview.png"]
+
+            preview = None
+            for file in previews:
+                if os.path.isfile(file):
+                    preview = self.link_preview(file)
+                    break
+
+            yield {
+                "name": checkpoint.model_name,
+                "filename": path,
+                "preview": preview,
+                "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
+                "local_preview": path + ".png",
+            }
+
+    def allowed_directories_for_previews(self):
+        return [shared.cmd_opts.ckpt_dir, sd_models.model_path]
+

+ 1 - 1
modules/ui_extra_networks_hypernets.py

@@ -19,7 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
             preview = None
             for file in previews:
                 if os.path.isfile(file):
-                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+                    preview = self.link_preview(file)
                     break
 
             yield {

+ 1 - 1
modules/ui_extra_networks_textual_inversion.py

@@ -19,7 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
 
             preview = None
             if os.path.isfile(preview_file):
-                preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+                preview = self.link_preview(preview_file)
 
             yield {
                 "name": embedding.name,

+ 5 - 1
webui.py

@@ -12,7 +12,7 @@ from packaging import version
 import logging
 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
-from modules import import_hook, errors, extra_networks
+from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
 from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
 from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 
@@ -119,6 +119,7 @@ def initialize():
     ui_extra_networks.intialize()
     ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
     ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+    ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
 
     extra_networks.initialize()
     extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
@@ -227,6 +228,8 @@ def webui():
         if launch_api:
             create_api(app)
 
+        ui_extra_networks.add_pages_to_demo(app)
+
         modules.script_callbacks.app_started_callback(shared.demo, app)
 
         wait_on_server(shared.demo)
@@ -254,6 +257,7 @@ def webui():
         ui_extra_networks.intialize()
         ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
         ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+        ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
 
         extra_networks.initialize()
         extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())