Răsfoiți Sursa

add callback for creating a tab in train UI

AUTOMATIC 2 ani în urmă
părinte
comite
1610b32584
2 a modificat fișierele cu 29 adăugiri și 2 ștergeri
  1. 25 2
      modules/script_callbacks.py
  2. 4 0
      modules/ui.py

+ 25 - 2
modules/script_callbacks.py

@@ -7,6 +7,7 @@ from typing import Optional
 from fastapi import FastAPI
 from fastapi import FastAPI
 from gradio import Blocks
 from gradio import Blocks
 
 
+
 def report_exception(c, job):
 def report_exception(c, job):
     print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
     print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
     print(traceback.format_exc(), file=sys.stderr)
     print(traceback.format_exc(), file=sys.stderr)
@@ -45,15 +46,21 @@ class CFGDenoiserParams:
         """Total number of sampling steps planned"""
         """Total number of sampling steps planned"""
 
 
 
 
+class UiTrainTabParams:
+    def __init__(self, txt2img_preview_params):
+        self.txt2img_preview_params = txt2img_preview_params
+
+
 ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
 ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
 callback_map = dict(
 callback_map = dict(
     callbacks_app_started=[],
     callbacks_app_started=[],
     callbacks_model_loaded=[],
     callbacks_model_loaded=[],
     callbacks_ui_tabs=[],
     callbacks_ui_tabs=[],
+    callbacks_ui_train_tabs=[],
     callbacks_ui_settings=[],
     callbacks_ui_settings=[],
     callbacks_before_image_saved=[],
     callbacks_before_image_saved=[],
     callbacks_image_saved=[],
     callbacks_image_saved=[],
-    callbacks_cfg_denoiser=[]
+    callbacks_cfg_denoiser=[],
 )
 )
 
 
 
 
@@ -61,6 +68,7 @@ def clear_callbacks():
     for callback_list in callback_map.values():
     for callback_list in callback_map.values():
         callback_list.clear()
         callback_list.clear()
 
 
+
 def app_started_callback(demo: Optional[Blocks], app: FastAPI):
 def app_started_callback(demo: Optional[Blocks], app: FastAPI):
     for c in callback_map['callbacks_app_started']:
     for c in callback_map['callbacks_app_started']:
         try:
         try:
@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
 
 
 def ui_tabs_callback():
 def ui_tabs_callback():
     res = []
     res = []
-    
+
     for c in callback_map['callbacks_ui_tabs']:
     for c in callback_map['callbacks_ui_tabs']:
         try:
         try:
             res += c.callback() or []
             res += c.callback() or []
@@ -89,6 +97,14 @@ def ui_tabs_callback():
     return res
     return res
 
 
 
 
+def ui_train_tabs_callback(params: UiTrainTabParams):
+    for c in callback_map['callbacks_ui_train_tabs']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'callbacks_ui_train_tabs')
+
+
 def ui_settings_callback():
 def ui_settings_callback():
     for c in callback_map['callbacks_ui_settings']:
     for c in callback_map['callbacks_ui_settings']:
         try:
         try:
@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
     add_callback(callback_map['callbacks_ui_tabs'], callback)
     add_callback(callback_map['callbacks_ui_tabs'], callback)
 
 
 
 
+def on_ui_train_tabs(callback):
+    """register a function to be called when the UI is creating new tabs for the train tab.
+    Create your new tabs with gr.Tab.
+    """
+    add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+
+
 def on_ui_settings(callback):
 def on_ui_settings(callback):
     """register a function to be called before UI settings are populated; add your settings
     """register a function to be called before UI settings are populated; add your settings
     by using shared.opts.add_option(shared.OptionInfo(...)) """
     by using shared.opts.add_option(shared.OptionInfo(...)) """

+ 4 - 0
modules/ui.py

@@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
                         train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
                         train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
                         train_embedding = gr.Button(value="Train Embedding", variant='primary')
                         train_embedding = gr.Button(value="Train Embedding", variant='primary')
 
 
+                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+                script_callbacks.ui_train_tabs_callback(params)
+
             with gr.Column():
             with gr.Column():
                 progressbar = gr.HTML(elem_id="ti_progressbar")
                 progressbar = gr.HTML(elem_id="ti_progressbar")
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)