scripts_postprocessing.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import dataclasses
  2. import os
  3. import gradio as gr
  4. from modules import errors, shared
  5. @dataclasses.dataclass
  6. class PostprocessedImageSharedInfo:
  7. target_width: int = None
  8. target_height: int = None
  9. class PostprocessedImage:
  10. def __init__(self, image):
  11. self.image = image
  12. self.info = {}
  13. self.shared = PostprocessedImageSharedInfo()
  14. self.extra_images = []
  15. self.nametags = []
  16. self.disable_processing = False
  17. self.caption = None
  18. def get_suffix(self, used_suffixes=None):
  19. used_suffixes = {} if used_suffixes is None else used_suffixes
  20. suffix = "-".join(self.nametags)
  21. if suffix:
  22. suffix = "-" + suffix
  23. if suffix not in used_suffixes:
  24. used_suffixes[suffix] = 1
  25. return suffix
  26. for i in range(1, 100):
  27. proposed_suffix = suffix + "-" + str(i)
  28. if proposed_suffix not in used_suffixes:
  29. used_suffixes[proposed_suffix] = 1
  30. return proposed_suffix
  31. return suffix
  32. def create_copy(self, new_image, *, nametags=None, disable_processing=False):
  33. pp = PostprocessedImage(new_image)
  34. pp.shared = self.shared
  35. pp.nametags = self.nametags.copy()
  36. pp.info = self.info.copy()
  37. pp.disable_processing = disable_processing
  38. if nametags is not None:
  39. pp.nametags += nametags
  40. return pp
  41. class ScriptPostprocessing:
  42. filename = None
  43. controls = None
  44. args_from = None
  45. args_to = None
  46. order = 1000
  47. """scripts will be ordred by this value in postprocessing UI"""
  48. name = None
  49. """this function should return the title of the script."""
  50. group = None
  51. """A gr.Group component that has all script's UI inside it"""
  52. def ui(self):
  53. """
  54. This function should create gradio UI elements. See https://gradio.app/docs/#components
  55. The return value should be a dictionary that maps parameter names to components used in processing.
  56. Values of those components will be passed to process() function.
  57. """
  58. pass
  59. def process(self, pp: PostprocessedImage, **args):
  60. """
  61. This function is called to postprocess the image.
  62. args contains a dictionary with all values returned by components from ui()
  63. """
  64. pass
  65. def process_firstpass(self, pp: PostprocessedImage, **args):
  66. """
  67. Called for all scripts before calling process(). Scripts can examine the image here and set fields
  68. of the pp object to communicate things to other scripts.
  69. args contains a dictionary with all values returned by components from ui()
  70. """
  71. pass
  72. def image_changed(self):
  73. pass
  74. def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
  75. try:
  76. res = func(*args, **kwargs)
  77. return res
  78. except Exception as e:
  79. errors.display(e, f"calling {filename}/{funcname}")
  80. return default
  81. class ScriptPostprocessingRunner:
  82. def __init__(self):
  83. self.scripts = None
  84. self.ui_created = False
  85. def initialize_scripts(self, scripts_data):
  86. self.scripts = []
  87. for script_data in scripts_data:
  88. script: ScriptPostprocessing = script_data.script_class()
  89. script.filename = script_data.path
  90. if script.name == "Simple Upscale":
  91. continue
  92. self.scripts.append(script)
  93. def create_script_ui(self, script, inputs):
  94. script.args_from = len(inputs)
  95. script.args_to = len(inputs)
  96. script.controls = wrap_call(script.ui, script.filename, "ui")
  97. for control in script.controls.values():
  98. control.custom_script_source = os.path.basename(script.filename)
  99. inputs += list(script.controls.values())
  100. script.args_to = len(inputs)
  101. def scripts_in_preferred_order(self):
  102. if self.scripts is None:
  103. import modules.scripts
  104. self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
  105. scripts_order = shared.opts.postprocessing_operation_order
  106. def script_score(name):
  107. for i, possible_match in enumerate(scripts_order):
  108. if possible_match == name:
  109. return i
  110. return len(self.scripts)
  111. script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
  112. return sorted(self.scripts, key=lambda x: script_scores[x.name])
  113. def setup_ui(self):
  114. inputs = []
  115. for script in self.scripts_in_preferred_order():
  116. with gr.Row() as group:
  117. self.create_script_ui(script, inputs)
  118. script.group = group
  119. self.ui_created = True
  120. return inputs
  121. def run(self, pp: PostprocessedImage, args):
  122. scripts = []
  123. for script in self.scripts_in_preferred_order():
  124. script_args = args[script.args_from:script.args_to]
  125. process_args = {}
  126. for (name, _component), value in zip(script.controls.items(), script_args):
  127. process_args[name] = value
  128. scripts.append((script, process_args))
  129. for script, process_args in scripts:
  130. script.process_firstpass(pp, **process_args)
  131. all_images = [pp]
  132. for script, process_args in scripts:
  133. if shared.state.skipped:
  134. break
  135. shared.state.job = script.name
  136. for single_image in all_images.copy():
  137. if not single_image.disable_processing:
  138. script.process(single_image, **process_args)
  139. for extra_image in single_image.extra_images:
  140. if not isinstance(extra_image, PostprocessedImage):
  141. extra_image = single_image.create_copy(extra_image)
  142. all_images.append(extra_image)
  143. single_image.extra_images.clear()
  144. pp.extra_images = all_images[1:]
  145. def create_args_for_run(self, scripts_args):
  146. if not self.ui_created:
  147. with gr.Blocks(analytics_enabled=False):
  148. self.setup_ui()
  149. scripts = self.scripts_in_preferred_order()
  150. args = [None] * max([x.args_to for x in scripts])
  151. for script in scripts:
  152. script_args_dict = scripts_args.get(script.name, None)
  153. if script_args_dict is not None:
  154. for i, name in enumerate(script.controls):
  155. args[script.args_from + i] = script_args_dict.get(name, None)
  156. return args
  157. def image_changed(self):
  158. for script in self.scripts_in_preferred_order():
  159. script.image_changed()