Parcourir la source

add names to callbacks

AUTOMATIC1111 il y a 1 an
Parent
commit
0411eced89
2 fichiers modifiés avec 91 ajouts et 52 suppressions
  1. 17 0
      modules/extensions.py
  2. 74 52
      modules/script_callbacks.py

+ 17 - 0
modules/extensions.py

@@ -186,6 +186,7 @@ class Extension:
 
 def list_extensions():
     extensions.clear()
+    extension_paths.clear()
 
     if shared.cmd_opts.disable_all_extensions:
         print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
@@ -220,6 +221,7 @@ def list_extensions():
             is_builtin = dirname == extensions_builtin_dir
             extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
             extensions.append(extension)
+            extension_paths[extension.path] = extension
             loaded_extensions[canonical_name] = extension
 
     # check for requirements
@@ -238,4 +240,19 @@ def list_extensions():
                 continue
 
 
+def find_extension(filename):
+    parentdir = os.path.dirname(os.path.realpath(filename))
+
+    while parentdir != filename:
+        extension = extension_paths.get(parentdir)
+        if extension is not None:
+            return extension
+
+        filename = parentdir
+        parentdir = os.path.dirname(filename)
+
+    return None
+
+
 extensions: list[Extension] = []
+extension_paths: dict[str, Extension] = {}

+ 74 - 52
modules/script_callbacks.py

@@ -1,13 +1,14 @@
+from __future__ import annotations
+
 import dataclasses
 import inspect
 import os
-from collections import namedtuple
 from typing import Optional, Any
 
 from fastapi import FastAPI
 from gradio import Blocks
 
-from modules import errors, timer
+from modules import errors, timer, extensions
 
 
 def report_exception(c, job):
@@ -116,7 +117,35 @@ class BeforeTokenCounterParams:
     is_positive: bool = True
 
 
-ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+@dataclasses.dataclass
+class ScriptCallback:
+    script: str
+    callback: ''
+    name: str = None
+
+
+def add_callback(callbacks, fun, *, name=None, category='unknown'):
+    stack = [x for x in inspect.stack() if x.filename != __file__]
+    filename = stack[0].filename if stack else 'unknown file'
+
+    extension = extensions.find_extension(filename)
+    extension_name = extension.canonical_name if extension else 'base'
+
+    callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}"
+    if name is not None:
+        callback_name += f'/{name}'
+
+    unique_callback_name = callback_name
+    for index in range(1000):
+        existing = any(x.name == unique_callback_name for x in callbacks)
+        if not existing:
+            break
+
+        unique_callback_name = f'{callback_name}-{index+1}'
+
+    callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
+
+
 callback_map = dict(
     callbacks_app_started=[],
     callbacks_model_loaded=[],
@@ -328,13 +357,6 @@ def before_token_counter_callback(params: BeforeTokenCounterParams):
             report_exception(c, 'before_token_counter')
 
 
-def add_callback(callbacks, fun):
-    stack = [x for x in inspect.stack() if x.filename != __file__]
-    filename = stack[0].filename if stack else 'unknown file'
-
-    callbacks.append(ScriptCallback(filename, fun))
-
-
 def remove_current_script_callbacks():
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if stack else 'unknown file'
@@ -351,24 +373,24 @@ def remove_callbacks_for_function(callback_func):
             callback_list.remove(callback_to_remove)
 
 
-def on_app_started(callback):
+def on_app_started(callback, *, name=None):
     """register a function to be called when the webui started, the gradio `Block` component and
     fastapi `FastAPI` object are passed as the arguments"""
-    add_callback(callback_map['callbacks_app_started'], callback)
+    add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')
 
 
-def on_before_reload(callback):
+def on_before_reload(callback, *, name=None):
     """register a function to be called just before the server reloads."""
-    add_callback(callback_map['callbacks_on_reload'], callback)
+    add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')
 
 
-def on_model_loaded(callback):
+def on_model_loaded(callback, *, name=None):
     """register a function to be called when the stable diffusion model is created; the model is
     passed as an argument; this function is also called when the script is reloaded. """
-    add_callback(callback_map['callbacks_model_loaded'], callback)
+    add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')
 
 
-def on_ui_tabs(callback):
+def on_ui_tabs(callback, *, name=None):
     """register a function to be called when the UI is creating new tabs.
     The function must either return a None, which means no new tabs to be added, or a list, where
     each element is a tuple:
@@ -378,71 +400,71 @@ def on_ui_tabs(callback):
     title is tab text displayed to user in the UI
     elem_id is HTML id for the tab
     """
-    add_callback(callback_map['callbacks_ui_tabs'], callback)
+    add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')
 
 
-def on_ui_train_tabs(callback):
+def on_ui_train_tabs(callback, *, name=None):
     """register a function to be called when the UI is creating new tabs for the train tab.
     Create your new tabs with gr.Tab.
     """
-    add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+    add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')
 
 
-def on_ui_settings(callback):
+def on_ui_settings(callback, *, name=None):
     """register a function to be called before UI settings are populated; add your settings
     by using shared.opts.add_option(shared.OptionInfo(...)) """
-    add_callback(callback_map['callbacks_ui_settings'], callback)
+    add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')
 
 
-def on_before_image_saved(callback):
+def on_before_image_saved(callback, *, name=None):
     """register a function to be called before an image is saved to a file.
     The callback is called with one argument:
         - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
     """
-    add_callback(callback_map['callbacks_before_image_saved'], callback)
+    add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')
 
 
-def on_image_saved(callback):
+def on_image_saved(callback, *, name=None):
     """register a function to be called after an image is saved to a file.
     The callback is called with one argument:
         - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
     """
-    add_callback(callback_map['callbacks_image_saved'], callback)
+    add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')
 
 
-def on_extra_noise(callback):
+def on_extra_noise(callback, *, name=None):
     """register a function to be called before adding extra noise in img2img or hires fix;
     The callback is called with one argument:
         - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
     """
-    add_callback(callback_map['callbacks_extra_noise'], callback)
+    add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')
 
 
-def on_cfg_denoiser(callback):
+def on_cfg_denoiser(callback, *, name=None):
     """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
     The callback is called with one argument:
         - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
     """
-    add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+    add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')
 
 
-def on_cfg_denoised(callback):
+def on_cfg_denoised(callback, *, name=None):
     """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
     The callback is called with one argument:
         - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
     """
-    add_callback(callback_map['callbacks_cfg_denoised'], callback)
+    add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')
 
 
-def on_cfg_after_cfg(callback):
+def on_cfg_after_cfg(callback, *, name=None):
     """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
     The callback is called with one argument:
         - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
     """
-    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+    add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')
 
 
-def on_before_component(callback):
+def on_before_component(callback, *, name=None):
     """register a function to be called before a component is created.
     The callback is called with arguments:
         - component - gradio component that is about to be created.
@@ -451,61 +473,61 @@ def on_before_component(callback):
     Use elem_id/label fields of kwargs to figure out which component it is.
     This can be useful to inject your own components somewhere in the middle of vanilla UI.
     """
-    add_callback(callback_map['callbacks_before_component'], callback)
+    add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')
 
 
-def on_after_component(callback):
+def on_after_component(callback, *, name=None):
     """register a function to be called after a component is created. See on_before_component for more."""
-    add_callback(callback_map['callbacks_after_component'], callback)
+    add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')
 
 
-def on_image_grid(callback):
+def on_image_grid(callback, *, name=None):
     """register a function to be called before making an image grid.
     The callback is called with one argument:
        - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
     """
-    add_callback(callback_map['callbacks_image_grid'], callback)
+    add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')
 
 
-def on_infotext_pasted(callback):
+def on_infotext_pasted(callback, *, name=None):
     """register a function to be called before applying an infotext.
     The callback is called with two arguments:
        - infotext: str - raw infotext.
        - result: dict[str, any] - parsed infotext parameters.
     """
-    add_callback(callback_map['callbacks_infotext_pasted'], callback)
+    add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')
 
 
-def on_script_unloaded(callback):
+def on_script_unloaded(callback, *, name=None):
     """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
     the script did should be reverted here"""
 
-    add_callback(callback_map['callbacks_script_unloaded'], callback)
+    add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')
 
 
-def on_before_ui(callback):
+def on_before_ui(callback, *, name=None):
     """register a function to be called before the UI is created."""
 
-    add_callback(callback_map['callbacks_before_ui'], callback)
+    add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')
 
 
-def on_list_optimizers(callback):
+def on_list_optimizers(callback, *, name=None):
     """register a function to be called when UI is making a list of cross attention optimization options.
     The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
     to it."""
 
-    add_callback(callback_map['callbacks_list_optimizers'], callback)
+    add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')
 
 
-def on_list_unets(callback):
+def on_list_unets(callback, *, name=None):
     """register a function to be called when UI is making a list of alternative options for unet.
     The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
 
-    add_callback(callback_map['callbacks_list_unets'], callback)
+    add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')
 
 
-def on_before_token_counter(callback):
+def on_before_token_counter(callback, *, name=None):
     """register a function to be called when UI is counting tokens for a prompt.
     The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
 
-    add_callback(callback_map['callbacks_before_token_counter'], callback)
+    add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')