script_callbacks.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. import sys
  2. import traceback
  3. from collections import namedtuple
  4. import inspect
  5. from typing import Optional
  6. from fastapi import FastAPI
  7. from gradio import Blocks
  8. def report_exception(c, job):
  9. print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
  10. print(traceback.format_exc(), file=sys.stderr)
  11. class ImageSaveParams:
  12. def __init__(self, image, p, filename, pnginfo):
  13. self.image = image
  14. """the PIL image itself"""
  15. self.p = p
  16. """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
  17. self.filename = filename
  18. """name of file that the image would be saved to"""
  19. self.pnginfo = pnginfo
  20. """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
  21. class CFGDenoiserParams:
  22. def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
  23. self.x = x
  24. """Latent image representation in the process of being denoised"""
  25. self.image_cond = image_cond
  26. """Conditioning image"""
  27. self.sigma = sigma
  28. """Current sigma noise step value"""
  29. self.sampling_step = sampling_step
  30. """Current Sampling step number"""
  31. self.total_sampling_steps = total_sampling_steps
  32. """Total number of sampling steps planned"""
  33. class UiTrainTabParams:
  34. def __init__(self, txt2img_preview_params):
  35. self.txt2img_preview_params = txt2img_preview_params
  36. ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
  37. callback_map = dict(
  38. callbacks_app_started=[],
  39. callbacks_model_loaded=[],
  40. callbacks_ui_tabs=[],
  41. callbacks_ui_train_tabs=[],
  42. callbacks_ui_settings=[],
  43. callbacks_before_image_saved=[],
  44. callbacks_image_saved=[],
  45. callbacks_cfg_denoiser=[],
  46. callbacks_before_component=[],
  47. callbacks_after_component=[],
  48. )
  49. def clear_callbacks():
  50. for callback_list in callback_map.values():
  51. callback_list.clear()
  52. def app_started_callback(demo: Optional[Blocks], app: FastAPI):
  53. for c in callback_map['callbacks_app_started']:
  54. try:
  55. c.callback(demo, app)
  56. except Exception:
  57. report_exception(c, 'app_started_callback')
  58. def model_loaded_callback(sd_model):
  59. for c in callback_map['callbacks_model_loaded']:
  60. try:
  61. c.callback(sd_model)
  62. except Exception:
  63. report_exception(c, 'model_loaded_callback')
  64. def ui_tabs_callback():
  65. res = []
  66. for c in callback_map['callbacks_ui_tabs']:
  67. try:
  68. res += c.callback() or []
  69. except Exception:
  70. report_exception(c, 'ui_tabs_callback')
  71. return res
  72. def ui_train_tabs_callback(params: UiTrainTabParams):
  73. for c in callback_map['callbacks_ui_train_tabs']:
  74. try:
  75. c.callback(params)
  76. except Exception:
  77. report_exception(c, 'callbacks_ui_train_tabs')
  78. def ui_settings_callback():
  79. for c in callback_map['callbacks_ui_settings']:
  80. try:
  81. c.callback()
  82. except Exception:
  83. report_exception(c, 'ui_settings_callback')
  84. def before_image_saved_callback(params: ImageSaveParams):
  85. for c in callback_map['callbacks_before_image_saved']:
  86. try:
  87. c.callback(params)
  88. except Exception:
  89. report_exception(c, 'before_image_saved_callback')
  90. def image_saved_callback(params: ImageSaveParams):
  91. for c in callback_map['callbacks_image_saved']:
  92. try:
  93. c.callback(params)
  94. except Exception:
  95. report_exception(c, 'image_saved_callback')
  96. def cfg_denoiser_callback(params: CFGDenoiserParams):
  97. for c in callback_map['callbacks_cfg_denoiser']:
  98. try:
  99. c.callback(params)
  100. except Exception:
  101. report_exception(c, 'cfg_denoiser_callback')
  102. def before_component_callback(component, **kwargs):
  103. for c in callback_map['callbacks_before_component']:
  104. try:
  105. c.callback(component, **kwargs)
  106. except Exception:
  107. report_exception(c, 'before_component_callback')
  108. def after_component_callback(component, **kwargs):
  109. for c in callback_map['callbacks_after_component']:
  110. try:
  111. c.callback(component, **kwargs)
  112. except Exception:
  113. report_exception(c, 'after_component_callback')
  114. def add_callback(callbacks, fun):
  115. stack = [x for x in inspect.stack() if x.filename != __file__]
  116. filename = stack[0].filename if len(stack) > 0 else 'unknown file'
  117. callbacks.append(ScriptCallback(filename, fun))
  118. def remove_current_script_callbacks():
  119. stack = [x for x in inspect.stack() if x.filename != __file__]
  120. filename = stack[0].filename if len(stack) > 0 else 'unknown file'
  121. if filename == 'unknown file':
  122. return
  123. for callback_list in callback_map.values():
  124. for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
  125. callback_list.remove(callback_to_remove)
  126. def remove_callbacks_for_function(callback_func):
  127. for callback_list in callback_map.values():
  128. for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
  129. callback_list.remove(callback_to_remove)
  130. def on_app_started(callback):
  131. """register a function to be called when the webui started, the gradio `Block` component and
  132. fastapi `FastAPI` object are passed as the arguments"""
  133. add_callback(callback_map['callbacks_app_started'], callback)
  134. def on_model_loaded(callback):
  135. """register a function to be called when the stable diffusion model is created; the model is
  136. passed as an argument"""
  137. add_callback(callback_map['callbacks_model_loaded'], callback)
  138. def on_ui_tabs(callback):
  139. """register a function to be called when the UI is creating new tabs.
  140. The function must either return a None, which means no new tabs to be added, or a list, where
  141. each element is a tuple:
  142. (gradio_component, title, elem_id)
  143. gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
  144. title is tab text displayed to user in the UI
  145. elem_id is HTML id for the tab
  146. """
  147. add_callback(callback_map['callbacks_ui_tabs'], callback)
  148. def on_ui_train_tabs(callback):
  149. """register a function to be called when the UI is creating new tabs for the train tab.
  150. Create your new tabs with gr.Tab.
  151. """
  152. add_callback(callback_map['callbacks_ui_train_tabs'], callback)
  153. def on_ui_settings(callback):
  154. """register a function to be called before UI settings are populated; add your settings
  155. by using shared.opts.add_option(shared.OptionInfo(...)) """
  156. add_callback(callback_map['callbacks_ui_settings'], callback)
  157. def on_before_image_saved(callback):
  158. """register a function to be called before an image is saved to a file.
  159. The callback is called with one argument:
  160. - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
  161. """
  162. add_callback(callback_map['callbacks_before_image_saved'], callback)
  163. def on_image_saved(callback):
  164. """register a function to be called after an image is saved to a file.
  165. The callback is called with one argument:
  166. - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
  167. """
  168. add_callback(callback_map['callbacks_image_saved'], callback)
  169. def on_cfg_denoiser(callback):
  170. """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
  171. The callback is called with one argument:
  172. - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
  173. """
  174. add_callback(callback_map['callbacks_cfg_denoiser'], callback)
  175. def on_before_component(callback):
  176. """register a function to be called before a component is created.
  177. The callback is called with arguments:
  178. - component - gradio component that is about to be created.
  179. - **kwargs - args to gradio.components.IOComponent.__init__ function
  180. Use elem_id/label fields of kwargs to figure out which component it is.
  181. This can be useful to inject your own components somewhere in the middle of vanilla UI.
  182. """
  183. add_callback(callback_map['callbacks_before_component'], callback)
  184. def on_after_component(callback):
  185. """register a function to be called after a component is created. See on_before_component for more."""
  186. add_callback(callback_map['callbacks_after_component'], callback)