Browse Source

Merge pull request #15205 from AUTOMATIC1111/callback_order

Callback order
AUTOMATIC1111 1 year ago
parent
commit
5bd2724765
7 changed files with 361 additions and 121 deletions
  1. 41 0
      modules/extensions.py
  2. 169 73
      modules/script_callbacks.py
  3. 75 42
      modules/scripts.py
  4. 42 0
      modules/shared_items.py
  5. 7 1
      modules/ui_settings.py
  6. 23 5
      modules/util.py
  7. 4 0
      style.css

+ 41 - 0
modules/extensions.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import configparser
+import dataclasses
 import os
 import threading
 import re
@@ -22,6 +23,13 @@ def active():
         return [x for x in extensions if x.enabled]
 
 
+@dataclasses.dataclass
+class CallbackOrderInfo:
+    name: str
+    before: list
+    after: list
+
+
 class ExtensionMetadata:
     filename = "metadata.ini"
     config: configparser.ConfigParser
@@ -65,6 +73,22 @@ class ExtensionMetadata:
         # both "," and " " are accepted as separator
         return [x for x in re.split(r"[,\s]+", text.strip()) if x]
 
+    def list_callback_order_instructions(self):
+        for section in self.config.sections():
+            if not section.startswith("callbacks/"):
+                continue
+
+            callback_name = section[10:]
+
+            if not callback_name.startswith(self.canonical_name):
+                errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
+                continue
+
+            before = self.parse_list(self.config.get(section, 'Before', fallback=''))
+            after = self.parse_list(self.config.get(section, 'After', fallback=''))
+
+            yield CallbackOrderInfo(callback_name, before, after)
+
 
 class Extension:
     lock = threading.Lock()
@@ -188,6 +212,7 @@ class Extension:
 
 def list_extensions():
     extensions.clear()
+    extension_paths.clear()
 
     if shared.cmd_opts.disable_all_extensions:
         print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
@@ -222,6 +247,7 @@ def list_extensions():
             is_builtin = dirname == extensions_builtin_dir
             extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
             extensions.append(extension)
+            extension_paths[extension.path] = extension
             loaded_extensions[canonical_name] = extension
 
     # check for requirements
@@ -240,4 +266,19 @@ def list_extensions():
                 continue
 
 
+def find_extension(filename):
+    parentdir = os.path.dirname(os.path.realpath(filename))
+
+    while parentdir != filename:
+        extension = extension_paths.get(parentdir)
+        if extension is not None:
+            return extension
+
+        filename = parentdir
+        parentdir = os.path.dirname(filename)
+
+    return None
+
+
 extensions: list[Extension] = []
+extension_paths: dict[str, Extension] = {}

+ 169 - 73
modules/script_callbacks.py

@@ -1,13 +1,14 @@
+from __future__ import annotations
+
 import dataclasses
 import inspect
 import os
-from collections import namedtuple
 from typing import Optional, Any
 
 from fastapi import FastAPI
 from gradio import Blocks
 
-from modules import errors, timer
+from modules import errors, timer, extensions, shared, util
 
 
 def report_exception(c, job):
@@ -116,7 +117,105 @@ class BeforeTokenCounterParams:
     is_positive: bool = True
 
 
-ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+@dataclasses.dataclass
+class ScriptCallback:
+    script: str
+    callback: any
+    name: str = None
+
+
+def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
+    if filename is None:
+        stack = [x for x in inspect.stack() if x.filename != __file__]
+        filename = stack[0].filename if stack else 'unknown file'
+
+    extension = extensions.find_extension(filename)
+    extension_name = extension.canonical_name if extension else 'base'
+
+    callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}"
+    if name is not None:
+        callback_name += f'/{name}'
+
+    unique_callback_name = callback_name
+    for index in range(1000):
+        existing = any(x.name == unique_callback_name for x in callbacks)
+        if not existing:
+            break
+
+        unique_callback_name = f'{callback_name}-{index+1}'
+
+    callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
+
+
+def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
+    callbacks = unordered_callbacks.copy()
+    callback_lookup = {x.name: x for x in callbacks}
+    dependencies = {}
+
+    order_instructions = {}
+    for extension in extensions.extensions:
+        for order_instruction in extension.metadata.list_callback_order_instructions():
+            if order_instruction.name in callback_lookup:
+                if order_instruction.name not in order_instructions:
+                    order_instructions[order_instruction.name] = []
+
+                order_instructions[order_instruction.name].append(order_instruction)
+
+    if order_instructions:
+        for callback in callbacks:
+            dependencies[callback.name] = []
+
+        for callback in callbacks:
+            for order_instruction in order_instructions.get(callback.name, []):
+                for after in order_instruction.after:
+                    if after not in callback_lookup:
+                        continue
+
+                    dependencies[callback.name].append(after)
+
+                for before in order_instruction.before:
+                    if before not in callback_lookup:
+                        continue
+
+                    dependencies[before].append(callback.name)
+
+        sorted_names = util.topological_sort(dependencies)
+        callbacks = [callback_lookup[x] for x in sorted_names]
+
+    if enable_user_sort:
+        for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
+            index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
+            if index is not None:
+                callbacks.insert(0, callbacks.pop(index))
+
+    return callbacks
+
+
+def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
+    if unordered_callbacks is None:
+        unordered_callbacks = callback_map.get('callbacks_' + category, [])
+
+    if not enable_user_sort:
+        return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
+
+    callbacks = ordered_callbacks_map.get(category)
+    if callbacks is not None and len(callbacks) == len(unordered_callbacks):
+        return callbacks
+
+    callbacks = sort_callbacks(category, unordered_callbacks)
+
+    ordered_callbacks_map[category] = callbacks
+    return callbacks
+
+
+def enumerate_callbacks():
+    for category, callbacks in callback_map.items():
+        if category.startswith('callbacks_'):
+            category = category[10:]
+
+        yield category, callbacks
+
+
 callback_map = dict(
     callbacks_app_started=[],
     callbacks_model_loaded=[],
@@ -141,14 +240,18 @@ callback_map = dict(
     callbacks_before_token_counter=[],
 )
 
+ordered_callbacks_map = {}
+
 
 def clear_callbacks():
     for callback_list in callback_map.values():
         callback_list.clear()
 
+    ordered_callbacks_map.clear()
+
 
 def app_started_callback(demo: Optional[Blocks], app: FastAPI):
-    for c in callback_map['callbacks_app_started']:
+    for c in ordered_callbacks('app_started'):
         try:
             c.callback(demo, app)
             timer.startup_timer.record(os.path.basename(c.script))
@@ -157,7 +260,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
 
 
 def app_reload_callback():
-    for c in callback_map['callbacks_on_reload']:
+    for c in ordered_callbacks('on_reload'):
         try:
             c.callback()
         except Exception:
@@ -165,7 +268,7 @@ def app_reload_callback():
 
 
 def model_loaded_callback(sd_model):
-    for c in callback_map['callbacks_model_loaded']:
+    for c in ordered_callbacks('model_loaded'):
         try:
             c.callback(sd_model)
         except Exception:
@@ -175,7 +278,7 @@ def model_loaded_callback(sd_model):
 def ui_tabs_callback():
     res = []
 
-    for c in callback_map['callbacks_ui_tabs']:
+    for c in ordered_callbacks('ui_tabs'):
         try:
             res += c.callback() or []
         except Exception:
@@ -185,7 +288,7 @@ def ui_tabs_callback():
 
 
 def ui_train_tabs_callback(params: UiTrainTabParams):
-    for c in callback_map['callbacks_ui_train_tabs']:
+    for c in ordered_callbacks('ui_train_tabs'):
         try:
             c.callback(params)
         except Exception:
@@ -193,7 +296,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
 
 
 def ui_settings_callback():
-    for c in callback_map['callbacks_ui_settings']:
+    for c in ordered_callbacks('ui_settings'):
         try:
             c.callback()
         except Exception:
@@ -201,7 +304,7 @@ def ui_settings_callback():
 
 
 def before_image_saved_callback(params: ImageSaveParams):
-    for c in callback_map['callbacks_before_image_saved']:
+    for c in ordered_callbacks('before_image_saved'):
         try:
             c.callback(params)
         except Exception:
@@ -209,7 +312,7 @@ def before_image_saved_callback(params: ImageSaveParams):
 
 
 def image_saved_callback(params: ImageSaveParams):
-    for c in callback_map['callbacks_image_saved']:
+    for c in ordered_callbacks('image_saved'):
         try:
             c.callback(params)
         except Exception:
@@ -217,7 +320,7 @@ def image_saved_callback(params: ImageSaveParams):
 
 
 def extra_noise_callback(params: ExtraNoiseParams):
-    for c in callback_map['callbacks_extra_noise']:
+    for c in ordered_callbacks('extra_noise'):
         try:
             c.callback(params)
         except Exception:
@@ -225,7 +328,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
 
 
 def cfg_denoiser_callback(params: CFGDenoiserParams):
-    for c in callback_map['callbacks_cfg_denoiser']:
+    for c in ordered_callbacks('cfg_denoiser'):
         try:
             c.callback(params)
         except Exception:
@@ -233,7 +336,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
 
 
 def cfg_denoised_callback(params: CFGDenoisedParams):
-    for c in callback_map['callbacks_cfg_denoised']:
+    for c in ordered_callbacks('cfg_denoised'):
         try:
             c.callback(params)
         except Exception:
@@ -241,7 +344,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
 
 
 def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
-    for c in callback_map['callbacks_cfg_after_cfg']:
+    for c in ordered_callbacks('cfg_after_cfg'):
         try:
             c.callback(params)
         except Exception:
@@ -249,7 +352,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
 
 
 def before_component_callback(component, **kwargs):
-    for c in callback_map['callbacks_before_component']:
+    for c in ordered_callbacks('before_component'):
         try:
             c.callback(component, **kwargs)
         except Exception:
@@ -257,7 +360,7 @@ def before_component_callback(component, **kwargs):
 
 
 def after_component_callback(component, **kwargs):
-    for c in callback_map['callbacks_after_component']:
+    for c in ordered_callbacks('after_component'):
         try:
             c.callback(component, **kwargs)
         except Exception:
@@ -265,7 +368,7 @@ def after_component_callback(component, **kwargs):
 
 
 def image_grid_callback(params: ImageGridLoopParams):
-    for c in callback_map['callbacks_image_grid']:
+    for c in ordered_callbacks('image_grid'):
         try:
             c.callback(params)
         except Exception:
@@ -273,7 +376,7 @@ def image_grid_callback(params: ImageGridLoopParams):
 
 
 def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
-    for c in callback_map['callbacks_infotext_pasted']:
+    for c in ordered_callbacks('infotext_pasted'):
         try:
             c.callback(infotext, params)
         except Exception:
@@ -281,7 +384,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
 
 
 def script_unloaded_callback():
-    for c in reversed(callback_map['callbacks_script_unloaded']):
+    for c in reversed(ordered_callbacks('script_unloaded')):
         try:
             c.callback()
         except Exception:
@@ -289,7 +392,7 @@ def script_unloaded_callback():
 
 
 def before_ui_callback():
-    for c in reversed(callback_map['callbacks_before_ui']):
+    for c in reversed(ordered_callbacks('before_ui')):
         try:
             c.callback()
         except Exception:
@@ -299,7 +402,7 @@ def before_ui_callback():
 def list_optimizers_callback():
     res = []
 
-    for c in callback_map['callbacks_list_optimizers']:
+    for c in ordered_callbacks('list_optimizers'):
         try:
             c.callback(res)
         except Exception:
@@ -311,7 +414,7 @@ def list_optimizers_callback():
 def list_unets_callback():
     res = []
 
-    for c in callback_map['callbacks_list_unets']:
+    for c in ordered_callbacks('list_unets'):
         try:
             c.callback(res)
         except Exception:
@@ -321,20 +424,13 @@ def list_unets_callback():
 
 
 def before_token_counter_callback(params: BeforeTokenCounterParams):
-    for c in callback_map['callbacks_before_token_counter']:
+    for c in ordered_callbacks('before_token_counter'):
         try:
             c.callback(params)
         except Exception:
             report_exception(c, 'before_token_counter')
 
 
-def add_callback(callbacks, fun):
-    stack = [x for x in inspect.stack() if x.filename != __file__]
-    filename = stack[0].filename if stack else 'unknown file'
-
-    callbacks.append(ScriptCallback(filename, fun))
-
-
 def remove_current_script_callbacks():
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if stack else 'unknown file'
@@ -351,24 +447,24 @@ def remove_callbacks_for_function(callback_func):
             callback_list.remove(callback_to_remove)
 
 
-def on_app_started(callback):
+def on_app_started(callback, *, name=None):
     """register a function to be called when the webui started, the gradio `Block` component and
     fastapi `FastAPI` object are passed as the arguments"""
-    add_callback(callback_map['callbacks_app_started'], callback)
+    add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')
 
 
-def on_before_reload(callback):
+def on_before_reload(callback, *, name=None):
     """register a function to be called just before the server reloads."""
-    add_callback(callback_map['callbacks_on_reload'], callback)
+    add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')
 
 
-def on_model_loaded(callback):
+def on_model_loaded(callback, *, name=None):
     """register a function to be called when the stable diffusion model is created; the model is
     passed as an argument; this function is also called when the script is reloaded. """
-    add_callback(callback_map['callbacks_model_loaded'], callback)
+    add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')
 
 
-def on_ui_tabs(callback):
+def on_ui_tabs(callback, *, name=None):
     """register a function to be called when the UI is creating new tabs.
     The function must either return a None, which means no new tabs to be added, or a list, where
     each element is a tuple:
@@ -378,71 +474,71 @@ def on_ui_tabs(callback):
     title is tab text displayed to user in the UI
     elem_id is HTML id for the tab
     """
-    add_callback(callback_map['callbacks_ui_tabs'], callback)
+    add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')
 
 
-def on_ui_train_tabs(callback):
+def on_ui_train_tabs(callback, *, name=None):
     """register a function to be called when the UI is creating new tabs for the train tab.
     Create your new tabs with gr.Tab.
     """
-    add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+    add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')
 
 
-def on_ui_settings(callback):
+def on_ui_settings(callback, *, name=None):
     """register a function to be called before UI settings are populated; add your settings
     by using shared.opts.add_option(shared.OptionInfo(...)) """
-    add_callback(callback_map['callbacks_ui_settings'], callback)
+    add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')
 
 
-def on_before_image_saved(callback):
+def on_before_image_saved(callback, *, name=None):
     """register a function to be called before an image is saved to a file.
     The callback is called with one argument:
         - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
     """
-    add_callback(callback_map['callbacks_before_image_saved'], callback)
+    add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')
 
 
-def on_image_saved(callback):
+def on_image_saved(callback, *, name=None):
     """register a function to be called after an image is saved to a file.
     The callback is called with one argument:
         - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
     """
-    add_callback(callback_map['callbacks_image_saved'], callback)
+    add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')
 
 
-def on_extra_noise(callback):
+def on_extra_noise(callback, *, name=None):
     """register a function to be called before adding extra noise in img2img or hires fix;
     The callback is called with one argument:
         - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
     """
-    add_callback(callback_map['callbacks_extra_noise'], callback)
+    add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')
 
 
-def on_cfg_denoiser(callback):
+def on_cfg_denoiser(callback, *, name=None):
     """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
     The callback is called with one argument:
         - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
     """
-    add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+    add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')
 
 
-def on_cfg_denoised(callback):
+def on_cfg_denoised(callback, *, name=None):
     """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
     The callback is called with one argument:
         - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
     """
-    add_callback(callback_map['callbacks_cfg_denoised'], callback)
+    add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')
 
 
-def on_cfg_after_cfg(callback):
+def on_cfg_after_cfg(callback, *, name=None):
     """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
     The callback is called with one argument:
         - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
     """
-    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+    add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')
 
 
-def on_before_component(callback):
+def on_before_component(callback, *, name=None):
     """register a function to be called before a component is created.
     The callback is called with arguments:
         - component - gradio component that is about to be created.
@@ -451,61 +547,61 @@ def on_before_component(callback):
     Use elem_id/label fields of kwargs to figure out which component it is.
     This can be useful to inject your own components somewhere in the middle of vanilla UI.
     """
-    add_callback(callback_map['callbacks_before_component'], callback)
+    add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')
 
 
-def on_after_component(callback):
+def on_after_component(callback, *, name=None):
     """register a function to be called after a component is created. See on_before_component for more."""
-    add_callback(callback_map['callbacks_after_component'], callback)
+    add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')
 
 
-def on_image_grid(callback):
+def on_image_grid(callback, *, name=None):
     """register a function to be called before making an image grid.
     The callback is called with one argument:
        - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
     """
-    add_callback(callback_map['callbacks_image_grid'], callback)
+    add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')
 
 
-def on_infotext_pasted(callback):
+def on_infotext_pasted(callback, *, name=None):
     """register a function to be called before applying an infotext.
     The callback is called with two arguments:
        - infotext: str - raw infotext.
        - result: dict[str, any] - parsed infotext parameters.
     """
-    add_callback(callback_map['callbacks_infotext_pasted'], callback)
+    add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')
 
 
-def on_script_unloaded(callback):
+def on_script_unloaded(callback, *, name=None):
     """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
     the script did should be reverted here"""
 
-    add_callback(callback_map['callbacks_script_unloaded'], callback)
+    add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')
 
 
-def on_before_ui(callback):
+def on_before_ui(callback, *, name=None):
     """register a function to be called before the UI is created."""
 
-    add_callback(callback_map['callbacks_before_ui'], callback)
+    add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')
 
 
-def on_list_optimizers(callback):
+def on_list_optimizers(callback, *, name=None):
     """register a function to be called when UI is making a list of cross attention optimization options.
     The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
     to it."""
 
-    add_callback(callback_map['callbacks_list_optimizers'], callback)
+    add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')
 
 
-def on_list_unets(callback):
+def on_list_unets(callback, *, name=None):
     """register a function to be called when UI is making a list of alternative options for unet.
     The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
 
-    add_callback(callback_map['callbacks_list_unets'], callback)
+    add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')
 
 
-def on_before_token_counter(callback):
+def on_before_token_counter(callback, *, name=None):
     """register a function to be called when UI is counting tokens for a prompt.
     The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
 
-    add_callback(callback_map['callbacks_before_token_counter'], callback)
+    add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')

+ 75 - 42
modules/scripts.py

@@ -7,7 +7,9 @@ from dataclasses import dataclass
 
 import gradio as gr
 
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util
+
+topological_sort = util.topological_sort
 
 AlwaysVisible = object()
 
@@ -138,7 +140,6 @@ class Script:
         """
         pass
 
-
     def before_process(self, p, *args):
         """
         This function is called very early during processing begins for AlwaysVisible scripts.
@@ -369,29 +370,6 @@ scripts_data = []
 postprocessing_scripts_data = []
 ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
 
-def topological_sort(dependencies):
-    """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
-    Ignores errors relating to missing dependeencies or circular dependencies
-    """
-
-    visited = {}
-    result = []
-
-    def inner(name):
-        visited[name] = True
-
-        for dep in dependencies.get(name, []):
-            if dep in dependencies and dep not in visited:
-                inner(dep)
-
-        result.append(name)
-
-    for depname in dependencies:
-        if depname not in visited:
-            inner(depname)
-
-    return result
-
 
 @dataclass
 class ScriptWithDependencies:
@@ -562,6 +540,25 @@ class ScriptRunner:
         self.paste_field_names = []
         self.inputs = [None]
 
+        self.callback_map = {}
+        self.callback_names = [
+            'before_process',
+            'process',
+            'before_process_batch',
+            'after_extra_networks_activate',
+            'process_batch',
+            'postprocess',
+            'postprocess_batch',
+            'postprocess_batch_list',
+            'post_sample',
+            'on_mask_blend',
+            'postprocess_image',
+            'postprocess_maskoverlay',
+            'postprocess_image_after_composite',
+            'before_component',
+            'after_component',
+        ]
+
         self.on_before_component_elem_id = {}
         """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
 
@@ -600,6 +597,8 @@ class ScriptRunner:
                 self.scripts.append(script)
                 self.selectable_scripts.append(script)
 
+        self.callback_map.clear()
+
         self.apply_on_before_component_callbacks()
 
     def apply_on_before_component_callbacks(self):
@@ -769,8 +768,42 @@ class ScriptRunner:
 
         return processed
 
+    def list_scripts_for_method(self, method_name):
+        if method_name in ('before_component', 'after_component'):
+            return self.scripts
+        else:
+            return self.alwayson_scripts
+
+    def create_ordered_callbacks_list(self,  method_name, *, enable_user_sort=True):
+        script_list = self.list_scripts_for_method(method_name)
+        category = f'script_{method_name}'
+        callbacks = []
+
+        for script in script_list:
+            if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None):
+                continue
+
+            script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename)
+
+        return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort)
+
+    def ordered_callbacks(self, method_name, *, enable_user_sort=True):
+        script_list = self.list_scripts_for_method(method_name)
+        category = f'script_{method_name}'
+
+        scrpts_len, callbacks = self.callback_map.get(category, (-1, None))
+
+        if callbacks is None or scrpts_len != len(script_list):
+            callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort)
+            self.callback_map[category] = len(script_list), callbacks
+
+        return callbacks
+
+    def ordered_scripts(self, method_name):
+        return [x.callback for x in self.ordered_callbacks(method_name)]
+
     def before_process(self, p):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('before_process'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.before_process(p, *script_args)
@@ -778,7 +811,7 @@ class ScriptRunner:
                 errors.report(f"Error running before_process: {script.filename}", exc_info=True)
 
     def process(self, p):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('process'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.process(p, *script_args)
@@ -786,7 +819,7 @@ class ScriptRunner:
                 errors.report(f"Error running process: {script.filename}", exc_info=True)
 
     def before_process_batch(self, p, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('before_process_batch'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.before_process_batch(p, *script_args, **kwargs)
@@ -794,7 +827,7 @@ class ScriptRunner:
                 errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
 
     def after_extra_networks_activate(self, p, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('after_extra_networks_activate'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.after_extra_networks_activate(p, *script_args, **kwargs)
@@ -802,7 +835,7 @@ class ScriptRunner:
                 errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
 
     def process_batch(self, p, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('process_batch'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.process_batch(p, *script_args, **kwargs)
@@ -810,7 +843,7 @@ class ScriptRunner:
                 errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
 
     def postprocess(self, p, processed):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess(p, processed, *script_args)
@@ -818,7 +851,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
 
     def postprocess_batch(self, p, images, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_batch'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_batch(p, *script_args, images=images, **kwargs)
@@ -826,7 +859,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
 
     def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_batch_list'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_batch_list(p, pp, *script_args, **kwargs)
@@ -834,7 +867,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
 
     def post_sample(self, p, ps: PostSampleArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('post_sample'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.post_sample(p, ps, *script_args)
@@ -842,7 +875,7 @@ class ScriptRunner:
                 errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
 
     def on_mask_blend(self, p, mba: MaskBlendArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('on_mask_blend'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.on_mask_blend(p, mba, *script_args)
@@ -850,7 +883,7 @@ class ScriptRunner:
                 errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
 
     def postprocess_image(self, p, pp: PostprocessImageArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_image'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_image(p, pp, *script_args)
@@ -858,7 +891,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
 
     def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_maskoverlay'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_maskoverlay(p, ppmo, *script_args)
@@ -866,7 +899,7 @@ class ScriptRunner:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
 
     def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('postprocess_image_after_composite'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.postprocess_image_after_composite(p, pp, *script_args)
@@ -880,7 +913,7 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
 
-        for script in self.scripts:
+        for script in self.ordered_scripts('before_component'):
             try:
                 script.before_component(component, **kwargs)
             except Exception:
@@ -893,7 +926,7 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
 
-        for script in self.scripts:
+        for script in self.ordered_scripts('after_component'):
             try:
                 script.after_component(component, **kwargs)
             except Exception:
@@ -921,7 +954,7 @@ class ScriptRunner:
                     self.scripts[si].args_to = args_to
 
     def before_hr(self, p):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('before_hr'):
             try:
                 script_args = p.script_args[script.args_from:script.args_to]
                 script.before_hr(p, *script_args)
@@ -929,7 +962,7 @@ class ScriptRunner:
                 errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
 
     def setup_scrips(self, p, *, is_ui=True):
-        for script in self.alwayson_scripts:
+        for script in self.ordered_scripts('setup'):
             if not is_ui and script.setup_for_ui_only:
                 continue
 

+ 42 - 0
modules/shared_items.py

@@ -1,5 +1,8 @@
+import html
 import sys
 
+from modules import script_callbacks, scripts, ui_components
+from modules.options import OptionHTML, OptionInfo
 from modules.shared_cmd_options import cmd_opts
 
 
@@ -118,6 +121,45 @@ def ui_reorder_categories():
     yield "scripts"
 
 
+def callbacks_order_settings():
+    options = {
+        "sd_vae_explanation": OptionHTML("""
+    For categories below, callbacks added to dropdowns happen before others, in order listed.
+    """),
+
+    }
+
+    callback_options = {}
+
+    for category, _ in script_callbacks.enumerate_callbacks():
+        callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)
+
+    for method_name in scripts.scripts_txt2img.callback_names:
+        callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)
+
+    for method_name in scripts.scripts_img2img.callback_names:
+        callbacks = callback_options.get("script_" + method_name, [])
+
+        for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):
+            if any(x.name == addition.name for x in callbacks):
+                continue
+
+            callbacks.append(addition)
+
+        callback_options["script_" + method_name] = callbacks
+
+    for category, callbacks in callback_options.items():
+        if not callbacks:
+            continue
+
+        option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]})
+        option_info.needs_restart()
+        option_info.html("<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>")
+        options['prioritized_callbacks_' + category] = option_info
+
+    return options
+
+
 class Shared(sys.modules[__name__].__class__):
     """
     this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than

+ 7 - 1
modules/ui_settings.py

@@ -1,7 +1,8 @@
 import gradio as gr
 
-from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
 from modules.call_queue import wrap_gradio_call
+from modules.options import options_section
 from modules.shared import opts
 from modules.ui_components import FormRow
 from modules.ui_gradio_extensions import reload_javascript
@@ -108,6 +109,11 @@ class UiSettings:
 
         shared.settings_components = self.component_dict
 
+        # we add this as late as possible so that scripts have already registered their callbacks
+        opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), {
+            **shared_items.callbacks_order_settings(),
+        }))
+
         opts.reorder()
 
         with gr.Blocks(analytics_enabled=False) as settings_interface:

+ 23 - 5
modules/util.py

@@ -148,8 +148,26 @@ class MassFileLister:
         """Clear the cache of all directories."""
         self.cached_dirs.clear()
 
-    def update_file_entry(self, path):
-        """Update the cache for a specific directory."""
-        dirname, filename = os.path.split(path)
-        if cached_dir := self.cached_dirs.get(dirname):
-            cached_dir.update_entry(filename)
+
+def topological_sort(dependencies):
+    """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
+    Ignores errors relating to missing dependeencies or circular dependencies
+    """
+
+    visited = {}
+    result = []
+
+    def inner(name):
+        visited[name] = True
+
+        for dep in dependencies.get(name, []):
+            if dep in dependencies and dep not in visited:
+                inner(dep)
+
+        result.append(name)
+
+    for depname in dependencies:
+        if depname not in visited:
+            inner(depname)
+
+    return result

+ 4 - 0
style.css

@@ -528,6 +528,10 @@ table.popup-table .link{
     opacity: 0.75;
 }
 
+.settings-comment .info ol{
+    margin: 0.4em 0 0.8em 1em;
+}
+
 #sysinfo_download a.sysinfo_big_link{
     font-size: 24pt;
 }