ui_common.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. import csv
  2. import dataclasses
  3. import json
  4. import html
  5. import os
  6. from contextlib import nullcontext
  7. import gradio as gr
  8. from modules import call_queue, shared, ui_tempdir, util
  9. from modules.infotext_utils import image_from_url_text
  10. import modules.images
  11. from modules.ui_components import ToolButton
  12. import modules.infotext_utils as parameters_copypaste
  13. folder_symbol = '\U0001f4c2' # 📂
  14. refresh_symbol = '\U0001f504' # 🔄
  15. def update_generation_info(generation_info, html_info, img_index):
  16. try:
  17. generation_info = json.loads(generation_info)
  18. if img_index < 0 or img_index >= len(generation_info["infotexts"]):
  19. return html_info, gr.update()
  20. return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
  21. except Exception:
  22. pass
  23. # if the json parse or anything else fails, just return the old html_info
  24. return html_info, gr.update()
  25. def plaintext_to_html(text, classname=None):
  26. content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
  27. return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
  28. def update_logfile(logfile_path, fields):
  29. """Update a logfile from old format to new format to maintain CSV integrity."""
  30. with open(logfile_path, "r", encoding="utf8", newline="") as file:
  31. reader = csv.reader(file)
  32. rows = list(reader)
  33. # blank file: leave it as is
  34. if not rows:
  35. return
  36. # file is already synced, do nothing
  37. if len(rows[0]) == len(fields):
  38. return
  39. rows[0] = fields
  40. # append new fields to each row as empty values
  41. for row in rows[1:]:
  42. while len(row) < len(fields):
  43. row.append("")
  44. with open(logfile_path, "w", encoding="utf8", newline="") as file:
  45. writer = csv.writer(file)
  46. writer.writerows(rows)
  47. def save_files(js_data, images, do_make_zip, index):
  48. filenames = []
  49. fullfns = []
  50. parsed_infotexts = []
  51. # quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
  52. class MyObject:
  53. def __init__(self, d=None):
  54. if d is not None:
  55. for key, value in d.items():
  56. setattr(self, key, value)
  57. data = json.loads(js_data)
  58. p = MyObject(data)
  59. path = shared.opts.outdir_save
  60. save_to_dirs = shared.opts.use_save_to_dirs_for_ui
  61. extension: str = shared.opts.samples_format
  62. start_index = 0
  63. if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
  64. images = [images[index]]
  65. start_index = index
  66. os.makedirs(shared.opts.outdir_save, exist_ok=True)
  67. fields = [
  68. "prompt",
  69. "seed",
  70. "width",
  71. "height",
  72. "sampler",
  73. "cfgs",
  74. "steps",
  75. "filename",
  76. "negative_prompt",
  77. "sd_model_name",
  78. "sd_model_hash",
  79. ]
  80. logfile_path = os.path.join(shared.opts.outdir_save, "log.csv")
  81. # NOTE: ensure csv integrity when fields are added by
  82. # updating headers and padding with delimiters where needed
  83. if shared.opts.save_write_log_csv and os.path.exists(logfile_path):
  84. update_logfile(logfile_path, fields)
  85. with (open(logfile_path, "a", encoding="utf8", newline='') if shared.opts.save_write_log_csv else nullcontext()) as file:
  86. if file:
  87. at_start = file.tell() == 0
  88. writer = csv.writer(file)
  89. if at_start:
  90. writer.writerow(fields)
  91. for image_index, filedata in enumerate(images, start_index):
  92. image = image_from_url_text(filedata)
  93. is_grid = image_index < p.index_of_first_image
  94. p.batch_index = image_index-1
  95. parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], [])
  96. parsed_infotexts.append(parameters)
  97. fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
  98. filename = os.path.relpath(fullfn, path)
  99. filenames.append(filename)
  100. fullfns.append(fullfn)
  101. if txt_fullfn:
  102. filenames.append(os.path.basename(txt_fullfn))
  103. fullfns.append(txt_fullfn)
  104. if file:
  105. writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]])
  106. # Make Zip
  107. if do_make_zip:
  108. p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts]
  109. namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True)
  110. zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
  111. zip_filepath = os.path.join(path, f"{zip_filename}.zip")
  112. from zipfile import ZipFile
  113. with ZipFile(zip_filepath, "w") as zip_file:
  114. for i in range(len(fullfns)):
  115. with open(fullfns[i], mode="rb") as f:
  116. zip_file.writestr(filenames[i], f.read())
  117. fullfns.insert(0, zip_filepath)
  118. return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
  119. @dataclasses.dataclass
  120. class OutputPanel:
  121. gallery = None
  122. generation_info = None
  123. infotext = None
  124. html_log = None
  125. button_upscale = None
def create_output_panel(tabname, outdir, toprow=None):
    """Build the results column (gallery, action buttons, info panes) for a tab.

    Args:
        tabname: tab identifier used in every elem_id. ``"txt2img"``, ``"img2img"``
            and ``"extras"`` get special handling below.
        outdir: folder opened by the folder button when
            ``shared.opts.outdir_samples`` is empty.
        toprow: optional object; when given, its ``create_inline_toprow_image()``
            is called at the top of the results column.

    Returns:
        An ``OutputPanel`` holding the created components (``button_upscale`` is
        only assigned on the txt2img tab).
    """
    res = OutputPanel()

    def open_folder(f, images=None, index=None):
        """Open folder ``f`` with ``util.open_folder``.

        When ``shared.opts.open_dir_button_choice`` contains ``'Sub'``, the
        directory of the selected gallery image is opened instead. The exception
        is a gradio temp path when the choice does not also contain ``'temp'``.
        Does nothing when ``shared.cmd_opts.hide_ui_dir_config`` is set.
        """
        if shared.cmd_opts.hide_ui_dir_config:
            return

        try:
            if 'Sub' in shared.opts.open_dir_button_choice:
                # strip anything after the last '?' from the entry's "name", then take its directory
                image_dir = os.path.split(images[index]["name"].rsplit('?', 1)[0])[0]
                if 'temp' in shared.opts.open_dir_button_choice or not ui_tempdir.is_gradio_temp_path(image_dir):
                    f = image_dir
        except Exception:
            # best-effort: a missing/invalid selection (e.g. images is None) falls back to ``f``
            pass

        util.open_folder(f)

    with gr.Column(elem_id=f"{tabname}_results"):
        if toprow:
            toprow.create_inline_toprow_image()

        with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
            with gr.Group(elem_id=f"{tabname}_gallery_container"):
                res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)

            # row of small action buttons under the gallery
            with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
                open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")

                if tabname != "extras":
                    save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
                    save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")

                # "send to" buttons keyed by destination tab. Created for every tab
                # (including extras) because the paste-params loop at the end iterates them.
                buttons = {
                    'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
                    'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
                    'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
                }

                if tabname == 'txt2img':
                    res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.")

            # The client-side _js returns [gallery, selected_gallery_index()], which fn
            # receives as (images, index). The button itself is only an input placeholder.
            open_folder_button.click(
                fn=lambda images, index: open_folder(shared.opts.outdir_samples or outdir, images, index),
                _js="(y, w) => [y, selected_gallery_index()]",
                inputs=[
                    res.gallery,
                    open_folder_button,  # placeholder for index
                ],
                outputs=[],
            )

            if tabname != "extras":
                download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')

                with gr.Group():
                    res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
                    res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")

                    # hidden textbox holding the generation-info JSON consumed by update_generation_info / save_files
                    res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
                    if tabname == 'txt2img' or tabname == 'img2img':
                        # invisible button; presumably triggered from frontend JS when the selected image changes — TODO confirm
                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
                        generation_info_button.click(
                            fn=update_generation_info,
                            _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
                            inputs=[res.generation_info, res.infotext, res.infotext],  # third input is a placeholder for the index
                            outputs=[res.infotext, res.infotext],
                            show_progress=False,
                        )

                    # _js replaces the two trailing placeholder inputs with do_make_zip=false and the selected index
                    save.click(
                        fn=call_queue.wrap_gradio_call_no_job(save_files),
                        _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
                        inputs=[
                            res.generation_info,
                            res.gallery,
                            res.infotext,
                            res.infotext,
                        ],
                        outputs=[
                            download_files,
                            res.html_log,
                        ],
                        show_progress=False,
                    )

                    # same as save, but with do_make_zip=true
                    save_zip.click(
                        fn=call_queue.wrap_gradio_call_no_job(save_files),
                        _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
                        inputs=[
                            res.generation_info,
                            res.gallery,
                            res.infotext,
                            res.infotext,
                        ],
                        outputs=[
                            download_files,
                            res.html_log,
                        ]
                    )

            else:
                res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}')
                res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
                res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')

            # NOTE(review): modules.scripts is reached by attribute access only; this file's
            # imports bring in modules.images but not modules.scripts, so it relies on
            # modules.scripts having been imported elsewhere first.
            paste_field_names = []
            if tabname == "txt2img":
                paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
            elif tabname == "img2img":
                paste_field_names = modules.scripts.scripts_img2img.paste_field_names

            for paste_tabname, paste_button in buttons.items():
                parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                    paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery,
                    paste_field_names=paste_field_names
                ))

    return res
  225. def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
  226. refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
  227. label = None
  228. for comp in refresh_components:
  229. label = getattr(comp, 'label', None)
  230. if label is not None:
  231. break
  232. def refresh():
  233. refresh_method()
  234. args = refreshed_args() if callable(refreshed_args) else refreshed_args
  235. for k, v in args.items():
  236. for comp in refresh_components:
  237. setattr(comp, k, v)
  238. return [gr.update(**(args or {})) for _ in refresh_components] if len(refresh_components) > 1 else gr.update(**(args or {}))
  239. refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
  240. refresh_button.click(
  241. fn=refresh,
  242. inputs=[],
  243. outputs=refresh_components
  244. )
  245. return refresh_button
  246. def setup_dialog(button_show, dialog, *, button_close=None):
  247. """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
  248. dialog.visible = False
  249. button_show.click(
  250. fn=lambda: gr.update(visible=True),
  251. inputs=[],
  252. outputs=[dialog],
  253. ).then(fn=None, _js="function(){ popupId('" + dialog.elem_id + "'); }")
  254. if button_close:
  255. button_close.click(fn=None, _js="closePopup")