Browse Source

Merge pull request #16604 from Haoming02/ext-updt-parallel

Check for Extension Updates in Parallel
w-e-w 8 months ago
parent
commit
04903af798
2 changed files with 13 additions and 5 deletions
  1. 1 0
      modules/shared_options.py
  2. 12 5
      modules/ui_extensions.py

+ 1 - 0
modules/shared_options.py

@@ -128,6 +128,7 @@ options_templates.update(options_section(('system', "System", "system"), {
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
     "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
     "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
     "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
     "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
+    "concurrent_git_fetch_limit": OptionInfo(16, "Number of simultaneous extension update checks ", gr.Slider, {"step": 1, "minimum": 1, "maximum": 100}).info("reduce extension update check time"),
 }))
 }))
 
 
 options_templates.update(options_section(('profiler', "Profiler", "system"), {
 options_templates.update(options_section(('profiler', "Profiler", "system"), {

+ 12 - 5
modules/ui_extensions.py

@@ -1,5 +1,6 @@
 import json
 import json
 import os
 import os
+from concurrent.futures import ThreadPoolExecutor
 import threading
 import threading
 import time
 import time
 from datetime import datetime, timezone
 from datetime import datetime, timezone
@@ -106,18 +107,24 @@ def check_updates(id_task, disable_list):
     exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
     exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
     shared.state.job_count = len(exts)
     shared.state.job_count = len(exts)
 
 
-    for ext in exts:
-        shared.state.textinfo = ext.name
+    lock = threading.Lock()
 
 
+    def _check_update(ext):
         try:
         try:
             ext.check_updates()
             ext.check_updates()
         except FileNotFoundError as e:
         except FileNotFoundError as e:
             if 'FETCH_HEAD' not in str(e):
             if 'FETCH_HEAD' not in str(e):
                 raise
                 raise
         except Exception:
         except Exception:
-            errors.report(f"Error checking updates for {ext.name}", exc_info=True)
-
-        shared.state.nextjob()
+            with lock:
+                errors.report(f"Error checking updates for {ext.name}", exc_info=True)
+        with lock:
+            shared.state.textinfo = ext.name
+            shared.state.nextjob()
+
+    with ThreadPoolExecutor(max_workers=max(1, int(shared.opts.concurrent_git_fetch_limit))) as executor:
+        for ext in exts:
+            executor.submit(_check_update, ext)
 
 
     return extension_table(), ""
     return extension_table(), ""