Эх сурвалжийг харах

Add API for scripts to add elements anywhere in UI.

AUTOMATIC 2 жил өмнө
parent
commit
3596af0749

+ 35 - 0
modules/script_callbacks.py

@@ -61,6 +61,8 @@ callback_map = dict(
     callbacks_before_image_saved=[],
     callbacks_image_saved=[],
     callbacks_cfg_denoiser=[],
+    callbacks_before_component=[],
+    callbacks_after_component=[],
 )
 
 
@@ -137,6 +139,22 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
             report_exception(c, 'cfg_denoiser_callback')
 
 
+def before_component_callback(component, **kwargs):
+    for c in callback_map['callbacks_before_component']:
+        try:
+            c.callback(component, **kwargs)
+        except Exception:
+            report_exception(c, 'before_component_callback')
+
+
+def after_component_callback(component, **kwargs):
+    for c in callback_map['callbacks_after_component']:
+        try:
+            c.callback(component, **kwargs)
+        except Exception:
+            report_exception(c, 'after_component_callback')
+
+
 def add_callback(callbacks, fun):
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -220,3 +238,20 @@ def on_cfg_denoiser(callback):
         - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
     """
     add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+
+
+def on_before_component(callback):
+    """register a function to be called before a component is created.
+    The callback is called with arguments:
+        - component - gradio component that is about to be created.
+        - **kwargs - args to gradio.components.IOComponent.__init__ function
+
+    Use elem_id/label fields of kwargs to figure out which component it is.
+    This can be useful to inject your own components somewhere in the middle of vanilla UI.
+    """
+    add_callback(callback_map['callbacks_before_component'], callback)
+
+
+def on_after_component(callback):
+    """register a function to be called after a component is created. See on_before_component for more."""
+    add_callback(callback_map['callbacks_after_component'], callback)

+ 66 - 3
modules/scripts.py

@@ -17,6 +17,9 @@ class Script:
     args_to = None
     alwayson = False
 
+    is_txt2img = False
+    is_img2img = False
+
     """A gr.Group component that has all script's UI inside it"""
     group = None
 
@@ -93,6 +96,23 @@ class Script:
 
         pass
 
+    def before_component(self, component, **kwargs):
+        """
+        Called before a component is created.
+        Use elem_id/label fields of kwargs to figure out which component it is.
+        This can be useful to inject your own components somewhere in the middle of vanilla UI.
+        You can return created components in the ui() function to add them to the list of arguments for your processing functions
+        """
+
+        pass
+
+    def after_component(self, component, **kwargs):
+        """
+        Called after a component is created. Same as above.
+        """
+
+        pass
+
     def describe(self):
         """unused"""
         return ""
@@ -195,12 +215,18 @@ class ScriptRunner:
         self.titles = []
         self.infotext_fields = []
 
-    def setup_ui(self, is_img2img):
+    def initialize_scripts(self, is_img2img):
+        self.scripts.clear()
+        self.alwayson_scripts.clear()
+        self.selectable_scripts.clear()
+
         for script_class, path, basedir in scripts_data:
             script = script_class()
             script.filename = path
+            script.is_txt2img = not is_img2img
+            script.is_img2img = is_img2img
 
-            visibility = script.show(is_img2img)
+            visibility = script.show(script.is_img2img)
 
             if visibility == AlwaysVisible:
                 self.scripts.append(script)
@@ -211,6 +237,7 @@ class ScriptRunner:
                 self.scripts.append(script)
                 self.selectable_scripts.append(script)
 
+    def setup_ui(self):
         self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
 
         inputs = [None]
@@ -220,7 +247,7 @@ class ScriptRunner:
             script.args_from = len(inputs)
             script.args_to = len(inputs)
 
-            controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
+            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
 
             if controls is None:
                 return
@@ -320,6 +347,22 @@ class ScriptRunner:
                 print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
 
+    def before_component(self, component, **kwargs):
+        for script in self.scripts:
+            try:
+                script.before_component(component, **kwargs)
+            except Exception:
+                print(f"Error running before_component: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def after_component(self, component, **kwargs):
+        for script in self.scripts:
+            try:
+                script.after_component(component, **kwargs)
+            except Exception:
+                print(f"Error running after_component: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
     def reload_sources(self, cache):
         for si, script in list(enumerate(self.scripts)):
             args_from = script.args_from
@@ -341,6 +384,7 @@ class ScriptRunner:
 
 scripts_txt2img = ScriptRunner()
 scripts_img2img = ScriptRunner()
+scripts_current: ScriptRunner = None
 
 
 def reload_script_body_only():
@@ -357,3 +401,22 @@ def reload_scripts():
     scripts_txt2img = ScriptRunner()
     scripts_img2img = ScriptRunner()
 
+
+def IOComponent_init(self, *args, **kwargs):
+    if scripts_current is not None:
+        scripts_current.before_component(self, **kwargs)
+
+    script_callbacks.before_component_callback(self, **kwargs)
+
+    res = original_IOComponent_init(self, *args, **kwargs)
+
+    script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts_current is not None:
+        scripts_current.after_component(self, **kwargs)
+
+    return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+gr.components.IOComponent.__init__ = IOComponent_init

+ 10 - 2
modules/ui.py

@@ -695,6 +695,9 @@ def create_ui(wrap_gradio_gpu_call):
 
     parameters_copypaste.reset()
 
+    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
         txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
         dummy_component = gr.Label(visible=False)
@@ -737,7 +740,7 @@ def create_ui(wrap_gradio_gpu_call):
                 seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
 
                 with gr.Group():
-                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
+                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
 
             txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
             parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -846,6 +849,9 @@ def create_ui(wrap_gradio_gpu_call):
 
             token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
 
+    modules.scripts.scripts_current = modules.scripts.scripts_img2img
+    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
         img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
 
@@ -916,7 +922,7 @@ def create_ui(wrap_gradio_gpu_call):
                 seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
 
                 with gr.Group():
-                    custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
+                    custom_inputs = modules.scripts.scripts_img2img.setup_ui()
 
             img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
             parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
@@ -1065,6 +1071,8 @@ def create_ui(wrap_gradio_gpu_call):
             parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
             parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
 
+    modules.scripts.scripts_current = None
+
     with gr.Blocks(analytics_enabled=False) as extras_interface:
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='panel'):