Browse Source

Add a callback called before reloading the server

rucadi 2 years ago
parent
commit
eb5eb8aa11
2 changed files with 13 additions and 1 deletions
  1. 12 1
      modules/script_callbacks.py
  2. 1 0
      webui.py

+ 12 - 1
modules/script_callbacks.py

@@ -75,6 +75,7 @@ callback_map = dict(
     callbacks_script_unloaded=[],
     callbacks_before_ui=[],
     callbacks_on_polling=[],
+    callbacks_on_reload=[],
 )
 
 
@@ -82,7 +83,6 @@ def clear_callbacks():
     for callback_list in callback_map.values():
         callback_list.clear()
 
-
 def app_started_callback(demo: Optional[Blocks], app: FastAPI):
     for c in callback_map['callbacks_app_started']:
         try:
@@ -97,6 +97,14 @@ def app_polling_callback(demo: Optional[Blocks], app: FastAPI):
         except Exception:
             report_exception(c, 'callbacks_on_polling')
 
+def app_reload_callback(demo: Optional[Blocks], app: FastAPI):
+    for c in callback_map['callbacks_on_reload']:
+        try:
+            c.callback()
+        except Exception:
+            report_exception(c, 'callbacks_on_reload')
+
+
 def model_loaded_callback(sd_model):
     for c in callback_map['callbacks_model_loaded']:
         try:
@@ -238,6 +246,9 @@ def on_polling(callback):
     """register a function to be called on each polling of the server."""
     add_callback(callback_map['callbacks_on_polling'], callback)
 
+def on_before_reload(callback):
+    """register a function to be called just before the server reloads."""
+    add_callback(callback_map['callbacks_on_reload'], callback)
 
 def on_model_loaded(callback):
     """register a function to be called when the stable diffusion model is created; the model is

+ 1 - 0
webui.py

@@ -173,6 +173,7 @@ def wait_on_server(demo=None):
         time.sleep(0.5)
         modules.script_callbacks.app_polling_callback(None, demo)
         if shared.state.need_restart:
+            modules.script_callbacks.app_reload_callback(None, demo)
             shared.state.need_restart = False
             time.sleep(0.5)
             demo.close()