|
@@ -7,6 +7,7 @@ from typing import Optional
|
|
|
from fastapi import FastAPI
|
|
|
from gradio import Blocks
|
|
|
|
|
|
+
|
|
|
def report_exception(c, job):
|
|
|
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
|
|
print(traceback.format_exc(), file=sys.stderr)
|
|
@@ -45,15 +46,21 @@ class CFGDenoiserParams:
|
|
|
"""Total number of sampling steps planned"""
|
|
|
|
|
|
|
|
|
+class UiTrainTabParams:
|
|
|
+ def __init__(self, txt2img_preview_params):
|
|
|
+ self.txt2img_preview_params = txt2img_preview_params
|
|
|
+
|
|
|
+
|
|
|
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
|
|
callback_map = dict(
|
|
|
callbacks_app_started=[],
|
|
|
callbacks_model_loaded=[],
|
|
|
callbacks_ui_tabs=[],
|
|
|
+ callbacks_ui_train_tabs=[],
|
|
|
callbacks_ui_settings=[],
|
|
|
callbacks_before_image_saved=[],
|
|
|
callbacks_image_saved=[],
|
|
|
- callbacks_cfg_denoiser=[]
|
|
|
+ callbacks_cfg_denoiser=[],
|
|
|
)
|
|
|
|
|
|
|
|
@@ -61,6 +68,7 @@ def clear_callbacks():
|
|
|
for callback_list in callback_map.values():
|
|
|
callback_list.clear()
|
|
|
|
|
|
+
|
|
|
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
|
|
for c in callback_map['callbacks_app_started']:
|
|
|
try:
|
|
@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
|
|
|
|
|
|
def ui_tabs_callback():
|
|
|
res = []
|
|
|
-
|
|
|
+
|
|
|
for c in callback_map['callbacks_ui_tabs']:
|
|
|
try:
|
|
|
res += c.callback() or []
|
|
@@ -89,6 +97,14 @@ def ui_tabs_callback():
|
|
|
return res
|
|
|
|
|
|
|
|
|
+def ui_train_tabs_callback(params: UiTrainTabParams):
|
|
|
+ for c in callback_map['callbacks_ui_train_tabs']:
|
|
|
+ try:
|
|
|
+ c.callback(params)
|
|
|
+ except Exception:
|
|
|
+ report_exception(c, 'callbacks_ui_train_tabs')
|
|
|
+
|
|
|
+
|
|
|
def ui_settings_callback():
|
|
|
for c in callback_map['callbacks_ui_settings']:
|
|
|
try:
|
|
@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
|
|
|
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
|
|
|
|
|
|
|
|
+def on_ui_train_tabs(callback):
|
|
|
+ """register a function to be called when the UI is creating new tabs for the train tab.
|
|
|
+ Create your new tabs with gr.Tab.
|
|
|
+ """
|
|
|
+ add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
|
|
+
|
|
|
+
|
|
|
def on_ui_settings(callback):
|
|
|
"""register a function to be called before UI settings are populated; add your settings
|
|
|
by using shared.opts.add_option(shared.OptionInfo(...)) """
|