scripts_postprocessing.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. import re
  2. import dataclasses
  3. import os
  4. import gradio as gr
  5. from modules import errors, shared
  6. @dataclasses.dataclass
  7. class PostprocessedImageSharedInfo:
  8. target_width: int = None
  9. target_height: int = None
  10. class PostprocessedImage:
  11. def __init__(self, image):
  12. self.image = image
  13. self.info = {}
  14. self.shared = PostprocessedImageSharedInfo()
  15. self.extra_images = []
  16. self.nametags = []
  17. self.disable_processing = False
  18. self.caption = None
  19. def get_suffix(self, used_suffixes=None):
  20. used_suffixes = {} if used_suffixes is None else used_suffixes
  21. suffix = "-".join(self.nametags)
  22. if suffix:
  23. suffix = "-" + suffix
  24. if suffix not in used_suffixes:
  25. used_suffixes[suffix] = 1
  26. return suffix
  27. for i in range(1, 100):
  28. proposed_suffix = suffix + "-" + str(i)
  29. if proposed_suffix not in used_suffixes:
  30. used_suffixes[proposed_suffix] = 1
  31. return proposed_suffix
  32. return suffix
  33. def create_copy(self, new_image, *, nametags=None, disable_processing=False):
  34. pp = PostprocessedImage(new_image)
  35. pp.shared = self.shared
  36. pp.nametags = self.nametags.copy()
  37. pp.info = self.info.copy()
  38. pp.disable_processing = disable_processing
  39. if nametags is not None:
  40. pp.nametags += nametags
  41. return pp
  42. class ScriptPostprocessing:
  43. filename = None
  44. controls = None
  45. args_from = None
  46. args_to = None
  47. # define if the script should be used only in extras or main UI
  48. extra_only = None
  49. main_ui_only = None
  50. order = 1000
  51. """scripts will be ordred by this value in postprocessing UI"""
  52. name = None
  53. """this function should return the title of the script."""
  54. group = None
  55. """A gr.Group component that has all script's UI inside it"""
  56. def ui(self):
  57. """
  58. This function should create gradio UI elements. See https://gradio.app/docs/#components
  59. The return value should be a dictionary that maps parameter names to components used in processing.
  60. Values of those components will be passed to process() function.
  61. """
  62. pass
  63. def process(self, pp: PostprocessedImage, **args):
  64. """
  65. This function is called to postprocess the image.
  66. args contains a dictionary with all values returned by components from ui()
  67. """
  68. pass
  69. def process_firstpass(self, pp: PostprocessedImage, **args):
  70. """
  71. Called for all scripts before calling process(). Scripts can examine the image here and set fields
  72. of the pp object to communicate things to other scripts.
  73. args contains a dictionary with all values returned by components from ui()
  74. """
  75. pass
  76. def image_changed(self):
  77. pass
  78. tab_name = '' # used by ScriptPostprocessingForMainUI
  79. replace_pattern = re.compile(r'\s')
  80. rm_pattern = re.compile(r'[^a-z_0-9]')
  81. def elem_id(self, item_id):
  82. """
  83. Helper function to generate id for a HTML element
  84. constructs final id out of script name and user-supplied item_id
  85. 'script_extras_{self.name.lower()}_{item_id}'
  86. {tab_name} will append to the end of the id if set
  87. tab_name will be set to '_img2img' or '_txt2img' if use by ScriptPostprocessingForMainUI
  88. Extensions should use this function to generate element IDs
  89. """
  90. return self.elem_id_suffix(f'extras_{self.name.lower()}_{item_id}')
  91. def elem_id_suffix(self, base_id):
  92. """
  93. Append tab_name to the base_id
  94. Extensions that already have specific there element IDs and wish to keep their IDs the same when possible should use this function
  95. """
  96. base_id = self.rm_pattern.sub('', self.replace_pattern.sub('_', base_id))
  97. return f'{base_id}{self.tab_name}'
  98. def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
  99. try:
  100. res = func(*args, **kwargs)
  101. return res
  102. except Exception as e:
  103. errors.display(e, f"calling {filename}/{funcname}")
  104. return default
  105. class ScriptPostprocessingRunner:
  106. def __init__(self):
  107. self.scripts = None
  108. self.ui_created = False
  109. def initialize_scripts(self, scripts_data):
  110. self.scripts = []
  111. for script_data in scripts_data:
  112. script: ScriptPostprocessing = script_data.script_class()
  113. script.filename = script_data.path
  114. self.scripts.append(script)
  115. def create_script_ui(self, script, inputs):
  116. script.args_from = len(inputs)
  117. script.args_to = len(inputs)
  118. script.controls = wrap_call(script.ui, script.filename, "ui")
  119. for control in script.controls.values():
  120. control.custom_script_source = os.path.basename(script.filename)
  121. inputs += list(script.controls.values())
  122. script.args_to = len(inputs)
  123. def scripts_in_preferred_order(self):
  124. if self.scripts is None:
  125. import modules.scripts
  126. self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
  127. scripts_order = shared.opts.postprocessing_operation_order
  128. scripts_filter_out = set(shared.opts.postprocessing_disable_in_extras)
  129. def script_score(name):
  130. for i, possible_match in enumerate(scripts_order):
  131. if possible_match == name:
  132. return i
  133. return len(self.scripts)
  134. filtered_scripts = [script for script in self.scripts if script.name not in scripts_filter_out and not script.main_ui_only]
  135. script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(filtered_scripts)}
  136. return sorted(filtered_scripts, key=lambda x: script_scores[x.name])
  137. def setup_ui(self):
  138. inputs = []
  139. for script in self.scripts_in_preferred_order():
  140. with gr.Row() as group:
  141. self.create_script_ui(script, inputs)
  142. script.group = group
  143. self.ui_created = True
  144. return inputs
  145. def run(self, pp: PostprocessedImage, args):
  146. scripts = []
  147. for script in self.scripts_in_preferred_order():
  148. script_args = args[script.args_from:script.args_to]
  149. process_args = {}
  150. for (name, _component), value in zip(script.controls.items(), script_args):
  151. process_args[name] = value
  152. scripts.append((script, process_args))
  153. for script, process_args in scripts:
  154. script.process_firstpass(pp, **process_args)
  155. all_images = [pp]
  156. for script, process_args in scripts:
  157. if shared.state.skipped:
  158. break
  159. shared.state.job = script.name
  160. for single_image in all_images.copy():
  161. if not single_image.disable_processing:
  162. script.process(single_image, **process_args)
  163. for extra_image in single_image.extra_images:
  164. if not isinstance(extra_image, PostprocessedImage):
  165. extra_image = single_image.create_copy(extra_image)
  166. all_images.append(extra_image)
  167. single_image.extra_images.clear()
  168. pp.extra_images = all_images[1:]
  169. def create_args_for_run(self, scripts_args):
  170. if not self.ui_created:
  171. with gr.Blocks(analytics_enabled=False):
  172. self.setup_ui()
  173. scripts = self.scripts_in_preferred_order()
  174. args = [None] * max([x.args_to for x in scripts])
  175. for script in scripts:
  176. script_args_dict = scripts_args.get(script.name, None)
  177. if script_args_dict is not None:
  178. for i, name in enumerate(script.controls):
  179. args[script.args_from + i] = script_args_dict.get(name, None)
  180. return args
  181. def image_changed(self):
  182. for script in self.scripts_in_preferred_order():
  183. script.image_changed()