Эх сурвалжийг харах

directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
create HTML for extra network pages only on demand
allow directories starting with . to still list their models for lora, checkpoints, etc
keep "search" filter for extra networks when user refreshes the page

AUTOMATIC 2 жил өмнө
parent
commit
083dc3c76a

+ 1 - 5
extensions-builtin/Lora/lora.py

@@ -352,11 +352,7 @@ def list_available_loras():
 
     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
-    candidates = \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
-
+    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
     for filename in sorted(candidates, key=str.lower):
         if os.path.isdir(filename):
             continue

+ 1 - 1
html/extra-networks-card.html

@@ -6,7 +6,7 @@
 			<ul>
 				<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
 			</ul>
-			<span style="display:none" class='search_term'>{search_term}</span>
+			<span style="display:none" class='search_term{serach_only}'>{search_term}</span>
 		</div>
 		<span class='name'>{name}</span>
 		<span class='description'>{description}</span>

+ 21 - 4
javascript/extraNetworks.js

@@ -1,4 +1,3 @@
-
 function setupExtraNetworksForTab(tabname){
     gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
 
@@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
     tabs.appendChild(search)
     tabs.appendChild(refresh)
 
-    search.addEventListener("input", function(){
+    var applyFilter = function(){
         var searchTerm = search.value.toLowerCase()
 
         gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
+            var searchOnly = elem.querySelector('.search_only')
             var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
-            elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
+
+            var visible = text.indexOf(searchTerm) != -1
+
+            if(searchOnly && searchTerm.length < 4){
+                visible = false
+            }
+
+            elem.style.display = visible ? "" : "none"
         })
-    });
+    }
+
+    search.addEventListener("input", applyFilter);
+    applyFilter();
+
+    extraNetworksApplyFilter[tabname] = applyFilter;
+}
+
+function applyExtraNetworkFilter(tabname){
+    setTimeout(extraNetworksApplyFilter[tabname], 1);
 }
 
+var extraNetworksApplyFilter = {}
 var activePromptTextarea = {};
 
 function setupExtraNetworks(){

+ 8 - 19
modules/modelloader.py

@@ -22,9 +22,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
     """
     output = []
 
-    if ext_filter is None:
-        ext_filter = []
-
     try:
         places = []
 
@@ -39,22 +36,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
         places.append(model_path)
 
         for place in places:
-            if os.path.exists(place):
-                for file in glob.iglob(place + '**/**', recursive=True):
-                    full_path = file
-                    if os.path.isdir(full_path):
-                        continue
-                    if os.path.islink(full_path) and not os.path.exists(full_path):
-                        print(f"Skipping broken symlink: {full_path}")
-                        continue
-                    if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
-                        continue
-                    if len(ext_filter) != 0:
-                        model_name, extension = os.path.splitext(file)
-                        if extension not in ext_filter:
-                            continue
-                    if file not in output:
-                        output.append(full_path)
+            for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
+                if os.path.islink(full_path) and not os.path.exists(full_path):
+                    print(f"Skipping broken symlink: {full_path}")
+                    continue
+                if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+                    continue
+                if full_path not in output:
+                    output.append(full_path)
 
         if model_url is not None and len(output) == 0:
             if download_name is not None:

+ 17 - 0
modules/shared.py

@@ -726,3 +726,20 @@ def html(filename):
             return file.read()
 
     return ""
+
+
+def walk_files(path, allowed_extensions=None):
+    if not os.path.exists(path):
+        return
+
+    if allowed_extensions is not None:
+        allowed_extensions = set(allowed_extensions)
+
+    for root, dirs, files in os.walk(path):
+        for filename in files:
+            if allowed_extensions is not None:
+                _, ext = os.path.splitext(filename)
+                if ext not in allowed_extensions:
+                    continue
+
+            yield os.path.join(root, filename)

+ 49 - 20
modules/ui_extra_networks.py

@@ -89,19 +89,22 @@ class ExtraNetworksPage:
 
         subdirs = {}
         for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
-            for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
-                if not os.path.isdir(x):
-                    continue
+            for root, dirs, files in os.walk(parentdir):
+                for dirname in dirs:
+                    x = os.path.join(root, dirname)
 
-                subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
-                while subdir.startswith("/"):
-                    subdir = subdir[1:]
+                    if not os.path.isdir(x):
+                        continue
 
-                is_empty = len(os.listdir(x)) == 0
-                if not is_empty and not subdir.endswith("/"):
-                    subdir = subdir + "/"
+                    subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
+                    while subdir.startswith("/"):
+                        subdir = subdir[1:]
 
-                subdirs[subdir] = 1
+                    is_empty = len(os.listdir(x)) == 0
+                    if not is_empty and not subdir.endswith("/"):
+                        subdir = subdir + "/"
+
+                    subdirs[subdir] = 1
 
         if subdirs:
             subdirs = {"": 1, **subdirs}
@@ -157,8 +160,20 @@ class ExtraNetworksPage:
         if metadata:
             metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
 
+        local_path = ""
+        filename = item.get("filename", "")
+        for reldir in self.allowed_directories_for_previews():
+            absdir = os.path.abspath(reldir)
+
+            if filename.startswith(absdir):
+                local_path = filename[len(absdir):]
+
+        # if this is true, the item must not be show in the default view, and must instead only be
+        # shown when searching for it
+        serach_only = "/." in local_path or "\\." in local_path
+
         args = {
-            "style": f"'{height}{width}{background_image}'",
+            "style": f"'display: none; {height}{width}{background_image}'",
             "prompt": item.get("prompt", None),
             "tabname": json.dumps(tabname),
             "local_preview": json.dumps(item["local_preview"]),
@@ -168,6 +183,7 @@ class ExtraNetworksPage:
             "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
             "search_term": item.get("search_term", ""),
             "metadata_button": metadata_button,
+            "serach_only": " search_only" if serach_only else "",
         }
 
         return self.card_page.format(**args)
@@ -209,6 +225,11 @@ def intialize():
 class ExtraNetworksUi:
     def __init__(self):
         self.pages = None
+        """gradio HTML components related to extra networks' pages"""
+
+        self.page_contents = None
+        """HTML content of the above; empty initially, filled when extra pages have to be shown"""
+
         self.stored_extra_pages = None
 
         self.button_save_preview = None
@@ -236,17 +257,22 @@ def pages_in_preferred_order(pages):
 def create_ui(container, button, tabname):
     ui = ExtraNetworksUi()
     ui.pages = []
+    ui.pages_contents = []
     ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
     ui.tabname = tabname
 
     with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
         for page in ui.stored_extra_pages:
-            with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")):
+            page_id = page.title.lower().replace(" ", "_")
 
-                page_elem = gr.HTML(page.create_html(ui.tabname))
+            with gr.Tab(page.title, id=page_id):
+                elem_id = f"{tabname}_{page_id}_cards_html"
+                page_elem = gr.HTML('', elem_id=elem_id)
                 ui.pages.append(page_elem)
 
-    filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+                page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
+
+    gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
     button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
 
     ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
@@ -254,19 +280,22 @@ def create_ui(container, button, tabname):
 
     def toggle_visibility(is_visible):
         is_visible = not is_visible
-        return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+        if is_visible and not ui.pages_contents:
+            refresh()
+
+        return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
 
     state_visible = gr.State(value=False)
-    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
+    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
 
     def refresh():
-        res = []
-
         for pg in ui.stored_extra_pages:
             pg.refresh()
-            res.append(pg.create_html(ui.tabname))
 
-        return res
+        ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
+
+        return ui.pages_contents
 
     button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)