Эх сурвалжийг харах

add UI for reordering callbacks

AUTOMATIC1111 1 жил өмнө
parent
commit
7e5e67330b

+ 67 - 25
modules/script_callbacks.py

@@ -8,7 +8,7 @@ from typing import Optional, Any
 from fastapi import FastAPI
 from gradio import Blocks
 
-from modules import errors, timer, extensions
+from modules import errors, timer, extensions, shared
 
 
 def report_exception(c, job):
@@ -124,9 +124,10 @@ class ScriptCallback:
     name: str = None
 
 
-def add_callback(callbacks, fun, *, name=None, category='unknown'):
-    stack = [x for x in inspect.stack() if x.filename != __file__]
-    filename = stack[0].filename if stack else 'unknown file'
+def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
+    if filename is None:
+        stack = [x for x in inspect.stack() if x.filename != __file__]
+        filename = stack[0].filename if stack else 'unknown file'
 
     extension = extensions.find_extension(filename)
     extension_name = extension.canonical_name if extension else 'base'
@@ -146,6 +147,43 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'):
     callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
 
 
+def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
+    callbacks = unordered_callbacks.copy()
+
+    if enable_user_sort:
+        for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
+            index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
+            if index is not None:
+                callbacks.insert(0, callbacks.pop(index))
+
+    return callbacks
+
+
+def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
+    if unordered_callbacks is None:
+        unordered_callbacks = callback_map.get('callbacks_' + category, [])
+
+    if not enable_user_sort:
+        return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
+
+    callbacks = ordered_callbacks_map.get(category)
+    if callbacks is not None and len(callbacks) == len(unordered_callbacks):
+        return callbacks
+
+    callbacks = sort_callbacks(category, unordered_callbacks)
+
+    ordered_callbacks_map[category] = callbacks
+    return callbacks
+
+
+def enumerate_callbacks():
+    for category, callbacks in callback_map.items():
+        if category.startswith('callbacks_'):
+            category = category[10:]
+
+        yield category, callbacks
+
+
 callback_map = dict(
     callbacks_app_started=[],
     callbacks_model_loaded=[],
@@ -170,14 +208,18 @@ callback_map = dict(
     callbacks_before_token_counter=[],
 )
 
+ordered_callbacks_map = {}
+
 
 def clear_callbacks():
     for callback_list in callback_map.values():
         callback_list.clear()
 
+    ordered_callbacks_map.clear()
+
 
 def app_started_callback(demo: Optional[Blocks], app: FastAPI):
-    for c in callback_map['callbacks_app_started']:
+    for c in ordered_callbacks('app_started'):
         try:
             c.callback(demo, app)
             timer.startup_timer.record(os.path.basename(c.script))
@@ -186,7 +228,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
 
 
 def app_reload_callback():
-    for c in callback_map['callbacks_on_reload']:
+    for c in ordered_callbacks('on_reload'):
         try:
             c.callback()
         except Exception:
@@ -194,7 +236,7 @@ def app_reload_callback():
 
 
 def model_loaded_callback(sd_model):
-    for c in callback_map['callbacks_model_loaded']:
+    for c in ordered_callbacks('model_loaded'):
         try:
             c.callback(sd_model)
         except Exception:
@@ -204,7 +246,7 @@ def model_loaded_callback(sd_model):
 def ui_tabs_callback():
     res = []
 
-    for c in callback_map['callbacks_ui_tabs']:
+    for c in ordered_callbacks('ui_tabs'):
         try:
             res += c.callback() or []
         except Exception:
@@ -214,7 +256,7 @@ def ui_tabs_callback():
 
 
 def ui_train_tabs_callback(params: UiTrainTabParams):
-    for c in callback_map['callbacks_ui_train_tabs']:
+    for c in ordered_callbacks('ui_train_tabs'):
         try:
             c.callback(params)
         except Exception:
@@ -222,7 +264,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
 
 
 def ui_settings_callback():
-    for c in callback_map['callbacks_ui_settings']:
+    for c in ordered_callbacks('ui_settings'):
         try:
             c.callback()
         except Exception:
@@ -230,7 +272,7 @@ def ui_settings_callback():
 
 
 def before_image_saved_callback(params: ImageSaveParams):
-    for c in callback_map['callbacks_before_image_saved']:
+    for c in ordered_callbacks('before_image_saved'):
         try:
             c.callback(params)
         except Exception:
@@ -238,7 +280,7 @@ def before_image_saved_callback(params: ImageSaveParams):
 
 
 def image_saved_callback(params: ImageSaveParams):
-    for c in callback_map['callbacks_image_saved']:
+    for c in ordered_callbacks('image_saved'):
         try:
             c.callback(params)
         except Exception:
@@ -246,7 +288,7 @@ def image_saved_callback(params: ImageSaveParams):
 
 
 def extra_noise_callback(params: ExtraNoiseParams):
-    for c in callback_map['callbacks_extra_noise']:
+    for c in ordered_callbacks('extra_noise'):
         try:
             c.callback(params)
         except Exception:
@@ -254,7 +296,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
 
 
 def cfg_denoiser_callback(params: CFGDenoiserParams):
-    for c in callback_map['callbacks_cfg_denoiser']:
+    for c in ordered_callbacks('cfg_denoiser'):
         try:
             c.callback(params)
         except Exception:
@@ -262,7 +304,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
 
 
 def cfg_denoised_callback(params: CFGDenoisedParams):
-    for c in callback_map['callbacks_cfg_denoised']:
+    for c in ordered_callbacks('cfg_denoised'):
         try:
             c.callback(params)
         except Exception:
@@ -270,7 +312,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
 
 
 def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
-    for c in callback_map['callbacks_cfg_after_cfg']:
+    for c in ordered_callbacks('cfg_after_cfg'):
         try:
             c.callback(params)
         except Exception:
@@ -278,7 +320,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
 
 
 def before_component_callback(component, **kwargs):
-    for c in callback_map['callbacks_before_component']:
+    for c in ordered_callbacks('before_component'):
         try:
             c.callback(component, **kwargs)
         except Exception:
@@ -286,7 +328,7 @@ def before_component_callback(component, **kwargs):
 
 
 def after_component_callback(component, **kwargs):
-    for c in callback_map['callbacks_after_component']:
+    for c in ordered_callbacks('after_component'):
         try:
             c.callback(component, **kwargs)
         except Exception:
@@ -294,7 +336,7 @@ def after_component_callback(component, **kwargs):
 
 
 def image_grid_callback(params: ImageGridLoopParams):
-    for c in callback_map['callbacks_image_grid']:
+    for c in ordered_callbacks('image_grid'):
         try:
             c.callback(params)
         except Exception:
@@ -302,7 +344,7 @@ def image_grid_callback(params: ImageGridLoopParams):
 
 
 def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
-    for c in callback_map['callbacks_infotext_pasted']:
+    for c in ordered_callbacks('infotext_pasted'):
         try:
             c.callback(infotext, params)
         except Exception:
@@ -310,7 +352,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
 
 
 def script_unloaded_callback():
-    for c in reversed(callback_map['callbacks_script_unloaded']):
+    for c in reversed(ordered_callbacks('script_unloaded')):
         try:
             c.callback()
         except Exception:
@@ -318,7 +360,7 @@ def script_unloaded_callback():
 
 
 def before_ui_callback():
-    for c in reversed(callback_map['callbacks_before_ui']):
+    for c in reversed(ordered_callbacks('before_ui')):
         try:
             c.callback()
         except Exception:
@@ -328,7 +370,7 @@ def before_ui_callback():
 def list_optimizers_callback():
     res = []
 
-    for c in callback_map['callbacks_list_optimizers']:
+    for c in ordered_callbacks('list_optimizers'):
         try:
             c.callback(res)
         except Exception:
@@ -340,7 +382,7 @@ def list_optimizers_callback():
 def list_unets_callback():
     res = []
 
-    for c in callback_map['callbacks_list_unets']:
+    for c in ordered_callbacks('list_unets'):
         try:
             c.callback(res)
         except Exception:
@@ -350,7 +392,7 @@ def list_unets_callback():
 
 
 def before_token_counter_callback(params: BeforeTokenCounterParams):
-    for c in callback_map['callbacks_before_token_counter']:
+    for c in ordered_callbacks('before_token_counter'):
         try:
             c.callback(params)
         except Exception:

+ 72 - 18
modules/scripts.py

@@ -138,7 +138,6 @@ class Script:
         """
         pass
 
-
     def before_process(self, p, *args):
         """
         This function is called very early during processing begins for AlwaysVisible scripts.
@@ -562,6 +561,25 @@ class ScriptRunner:
         self.paste_field_names = []
         self.inputs = [None]
 
+        self.callback_map = {}
+        self.callback_names = [
+            'before_process',
+            'process',
+            'before_process_batch',
+            'after_extra_networks_activate',
+            'process_batch',
+            'postprocess',
+            'postprocess_batch',
+            'postprocess_batch_list',
+            'post_sample',
+            'on_mask_blend',
+            'postprocess_image',
+            'postprocess_maskoverlay',
+            'postprocess_image_after_composite',
+            'before_component',
+            'after_component',
+        ]
+
         self.on_before_component_elem_id = {}
         """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
 
@@ -600,6 +618,8 @@ class ScriptRunner:
                 self.scripts.append(script)
                 self.selectable_scripts.append(script)
 
+        self.callback_map.clear()
+
         self.apply_on_before_component_callbacks()
 
     def apply_on_before_component_callbacks(self):
@@ -769,8 +789,42 @@ class ScriptRunner:
 
         return processed
 
+    def list_scripts_for_method(self, method_name):
+        if method_name in ('before_component', 'after_component'):
+            return self.scripts
+        else:
+            return self.alwayson_scripts
+
+    def create_ordered_callbacks_list(self,  method_name, *, enable_user_sort=True):
+        script_list = self.list_scripts_for_method(method_name)
+        category = f'script_{method_name}'
+        callbacks = []
+
+        for script in script_list:
+            if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None):
+                continue
+
+            script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename)
+
+        return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort)
+
+    def ordered_callbacks(self, method_name, *, enable_user_sort=True):
+        script_list = self.list_scripts_for_method(method_name)
+        category = f'script_{method_name}'
+
+        scrpts_len, callbacks = self.callback_map.get(category, (-1, None))
+
+        if callbacks is None or scrpts_len != len(script_list):
+            callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort)
+            self.callback_map[category] = len(script_list), callbacks
+
+        return callbacks
+
+    def ordered_scripts(self, method_name):
+        return [x.callback for x in self.ordered_callbacks(method_name)]
+
     def before_process(self, p):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('before_process'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.before_process(p, *script_args)
@@ -778,7 +832,7 @@ class ScriptRunner:
                 errors.report(f"Error running before_process: {script.filename}", exc_info=True)
 
     def process(self, p):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('process'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.process(p, *script_args)
@@ -786,7 +840,7 @@ class ScriptRunner:
                 errors.report(f"Error running process: {script.filename}", exc_info=True)
 
     def before_process_batch(self, p, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('before_process_batch'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.before_process_batch(p, *script_args, **kwargs)
@@ -794,7 +848,7 @@ class ScriptRunner:
                 errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
 
     def after_extra_networks_activate(self, p, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('after_extra_networks_activate'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.after_extra_networks_activate(p, *script_args, **kwargs)
@@ -802,7 +856,7 @@ class ScriptRunner:
                 errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
 
     def process_batch(self, p, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('process_batch'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.process_batch(p, *script_args, **kwargs)
@@ -810,7 +864,7 @@ class ScriptRunner:
                 errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
 
     def postprocess(self, p, processed):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess(p, processed, *script_args)
@@ -818,7 +872,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
 
     def postprocess_batch(self, p, images, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_batch'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_batch(p, *script_args, images=images, **kwargs)
@@ -826,7 +880,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
 
     def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_batch_list'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_batch_list(p, pp, *script_args, **kwargs)
@@ -834,7 +888,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
 
     def post_sample(self, p, ps: PostSampleArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('post_sample'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.post_sample(p, ps, *script_args)
@@ -842,7 +896,7 @@ class ScriptRunner:
                 errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
 
     def on_mask_blend(self, p, mba: MaskBlendArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('on_mask_blend'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.on_mask_blend(p, mba, *script_args)
@@ -850,7 +904,7 @@ class ScriptRunner:
                 errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
 
     def postprocess_image(self, p, pp: PostprocessImageArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_image'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_image(p, pp, *script_args)
@@ -858,7 +912,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
 
     def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_maskoverlay'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_maskoverlay(p, ppmo, *script_args)
@@ -866,7 +920,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
 
     def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_image_after_composite'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_image_after_composite(p, pp, *script_args)
@@ -880,7 +934,7 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
 
-        for script in self.scripts:
+        for script in self.ordered_scripts('before_component'):
             try:
                 script.before_component(component, **kwargs)
             except Exception:
@@ -893,7 +947,7 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
 
-        for script in self.scripts:
+        for script in self.ordered_scripts('after_component'):
             try:
                 script.after_component(component, **kwargs)
             except Exception:
@@ -921,7 +975,7 @@ class ScriptRunner:
                     self.scripts[si].args_to = args_to
 
     def before_hr(self, p):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('before_hr'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.before_hr(p, *script_args)
@@ -929,7 +983,7 @@ class ScriptRunner:
                 errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
 
     def setup_scrips(self, p, *, is_ui=True):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('setup'):
             if not is_ui and script.setup_for_ui_only:
                 continue
 

+ 42 - 0
modules/shared_items.py

@@ -1,5 +1,8 @@
+import html
 import sys
 
+from modules import script_callbacks, scripts, ui_components
+from modules.options import OptionHTML, OptionInfo
 from modules.shared_cmd_options import cmd_opts
 
 
@@ -118,6 +121,45 @@ def ui_reorder_categories():
     yield "scripts"
 
 
+def callbacks_order_settings():
+    options = {
+        "sd_vae_explanation": OptionHTML("""
+    For categories below, callbacks added to dropdowns happen before others, in order listed.
+    """),
+
+    }
+
+    callback_options = {}
+
+    for category, _ in script_callbacks.enumerate_callbacks():
+        callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)
+
+    for method_name in scripts.scripts_txt2img.callback_names:
+        callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)
+
+    for method_name in scripts.scripts_img2img.callback_names:
+        callbacks = callback_options.get("script_" + method_name, [])
+
+        for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):
+            if any(x.name == addition.name for x in callbacks):
+                continue
+
+            callbacks.append(addition)
+
+        callback_options["script_" + method_name] = callbacks
+
+    for category, callbacks in callback_options.items():
+        if not callbacks:
+            continue
+
+        option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]})
+        option_info.needs_restart()
+        option_info.html("<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>")
+        options['prioritized_callbacks_' + category] = option_info
+
+    return options
+
+
 class Shared(sys.modules[__name__].__class__):
     """
     this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than

+ 7 - 1
modules/ui_settings.py

@@ -1,7 +1,8 @@
 import gradio as gr
 
-from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
 from modules.call_queue import wrap_gradio_call
+from modules.options import options_section
 from modules.shared import opts
 from modules.ui_components import FormRow
 from modules.ui_gradio_extensions import reload_javascript
@@ -108,6 +109,11 @@ class UiSettings:
 
         shared.settings_components = self.component_dict
 
+        # we add this as late as possible so that scripts have already registered their callbacks
+        opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), {
+            **shared_items.callbacks_order_settings(),
+        }))
+
         opts.reorder()
 
         with gr.Blocks(analytics_enabled=False) as settings_interface:

+ 4 - 0
style.css

@@ -528,6 +528,10 @@ table.popup-table .link{
     opacity: 0.75;
 }
 
+.settings-comment .info ol{
+    margin: 0.4em 0 0.8em 1em;
+}
+
 #sysinfo_download a.sysinfo_big_link{
     font-size: 24pt;
 }