# xyz_grid.py (32 KB)
  1. from collections import namedtuple
  2. from copy import copy
  3. from itertools import permutations, chain
  4. import random
  5. import csv
  6. from io import StringIO
  7. from PIL import Image
  8. import numpy as np
  9. import modules.scripts as scripts
  10. import gradio as gr
  11. from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
  12. from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
  13. from modules.shared import opts, state
  14. import modules.shared as shared
  15. import modules.sd_samplers
  16. import modules.sd_models
  17. import modules.sd_vae
  18. import re
  19. from modules.ui_components import ToolButton
  20. fill_values_symbol = "\U0001f4d2" # 📒
  21. AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
  22. def apply_field(field):
  23. def fun(p, x, xs):
  24. setattr(p, field, x)
  25. return fun
  26. def apply_prompt(p, x, xs):
  27. if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
  28. raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
  29. p.prompt = p.prompt.replace(xs[0], x)
  30. p.negative_prompt = p.negative_prompt.replace(xs[0], x)
  31. def apply_order(p, x, xs):
  32. token_order = []
  33. # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
  34. for token in x:
  35. token_order.append((p.prompt.find(token), token))
  36. token_order.sort(key=lambda t: t[0])
  37. prompt_parts = []
  38. # Split the prompt up, taking out the tokens
  39. for _, token in token_order:
  40. n = p.prompt.find(token)
  41. prompt_parts.append(p.prompt[0:n])
  42. p.prompt = p.prompt[n + len(token):]
  43. # Rebuild the prompt with the tokens in the order we want
  44. prompt_tmp = ""
  45. for idx, part in enumerate(prompt_parts):
  46. prompt_tmp += part
  47. prompt_tmp += x[idx]
  48. p.prompt = prompt_tmp + p.prompt
  49. def apply_sampler(p, x, xs):
  50. sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
  51. if sampler_name is None:
  52. raise RuntimeError(f"Unknown sampler: {x}")
  53. p.sampler_name = sampler_name
  54. def confirm_samplers(p, xs):
  55. for x in xs:
  56. if x.lower() not in sd_samplers.samplers_map:
  57. raise RuntimeError(f"Unknown sampler: {x}")
  58. def apply_checkpoint(p, x, xs):
  59. info = modules.sd_models.get_closet_checkpoint_match(x)
  60. if info is None:
  61. raise RuntimeError(f"Unknown checkpoint: {x}")
  62. p.override_settings['sd_model_checkpoint'] = info.name
  63. def confirm_checkpoints(p, xs):
  64. for x in xs:
  65. if modules.sd_models.get_closet_checkpoint_match(x) is None:
  66. raise RuntimeError(f"Unknown checkpoint: {x}")
  67. def apply_clip_skip(p, x, xs):
  68. opts.data["CLIP_stop_at_last_layers"] = x
  69. def apply_upscale_latent_space(p, x, xs):
  70. if x.lower().strip() != '0':
  71. opts.data["use_scale_latent_for_hires_fix"] = True
  72. else:
  73. opts.data["use_scale_latent_for_hires_fix"] = False
  74. def find_vae(name: str):
  75. if name.lower() in ['auto', 'automatic']:
  76. return modules.sd_vae.unspecified
  77. if name.lower() == 'none':
  78. return None
  79. else:
  80. choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
  81. if len(choices) == 0:
  82. print(f"No VAE found for {name}; using automatic")
  83. return modules.sd_vae.unspecified
  84. else:
  85. return modules.sd_vae.vae_dict[choices[0]]
  86. def apply_vae(p, x, xs):
  87. modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
  88. def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
  89. p.styles.extend(x.split(','))
  90. def apply_uni_pc_order(p, x, xs):
  91. opts.data["uni_pc_order"] = min(x, p.steps - 1)
  92. def apply_face_restore(p, opt, x):
  93. opt = opt.lower()
  94. if opt == 'codeformer':
  95. is_active = True
  96. p.face_restoration_model = 'CodeFormer'
  97. elif opt == 'gfpgan':
  98. is_active = True
  99. p.face_restoration_model = 'GFPGAN'
  100. else:
  101. is_active = opt in ('true', 'yes', 'y', '1')
  102. p.restore_faces = is_active
  103. def apply_override(field):
  104. def fun(p, x, xs):
  105. p.override_settings[field] = x
  106. return fun
  107. def format_value_add_label(p, opt, x):
  108. if type(x) == float:
  109. x = round(x, 8)
  110. return f"{opt.label}: {x}"
  111. def format_value(p, opt, x):
  112. if type(x) == float:
  113. x = round(x, 8)
  114. return x
  115. def format_value_join_list(p, opt, x):
  116. return ", ".join(x)
  117. def do_nothing(p, x, xs):
  118. pass
  119. def format_nothing(p, opt, x):
  120. return ""
  121. def str_permutations(x):
  122. """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
  123. return x
  124. class AxisOption:
  125. def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
  126. self.label = label
  127. self.type = type
  128. self.apply = apply
  129. self.format_value = format_value
  130. self.confirm = confirm
  131. self.cost = cost
  132. self.choices = choices
  133. class AxisOptionImg2Img(AxisOption):
  134. def __init__(self, *args, **kwargs):
  135. super().__init__(*args, **kwargs)
  136. self.is_img2img = True
  137. class AxisOptionTxt2Img(AxisOption):
  138. def __init__(self, *args, **kwargs):
  139. super().__init__(*args, **kwargs)
  140. self.is_img2img = False
# Every axis kind offered in the X/Y/Z type dropdowns.
# Script.ui keeps the plain AxisOption entries plus the AxisOptionTxt2Img/AxisOptionImg2Img
# entries that match the current tab. The dropdowns use type="index" into that filtered list,
# so entry order matters.
# `cost` ranks how expensive it is to switch values; Script.run iterates the costliest axis in
# the outermost loop. `choices`, when given, replaces the value textbox with a multiselect dropdown.
axis_options = [
    AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
    AxisOption("Seed", int, apply_field("seed")),
    AxisOption("Var. seed", int, apply_field("subseed")),
    AxisOption("Var. strength", float, apply_field("subseed_strength")),
    AxisOption("Steps", int, apply_field("steps")),
    AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    # Same label on both tabs; each variant lists the samplers valid for its tab.
    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
    AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
    # Highest cost: swapping checkpoints is the slowest change, so this axis is iterated outermost.
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
    AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
    AxisOption("Sigma Churn", float, apply_field("s_churn")),
    AxisOption("Sigma min", float, apply_field("s_tmin")),
    AxisOption("Sigma max", float, apply_field("s_tmax")),
    AxisOption("Sigma noise", float, apply_field("s_noise")),
    # These go through p.override_settings rather than attributes of p.
    AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
    AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
    AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
    AxisOption("Schedule rho", float, apply_override("rho")),
    AxisOption("Eta", float, apply_field("eta")),
    AxisOption("Clip skip", int, apply_clip_skip),
    AxisOption("Denoising", float, apply_field("denoising_strength")),
    AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
    AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
    AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
    AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
    AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
    AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
]
  176. def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
  177. hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
  178. ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
  179. title_texts = [[images.GridAnnotation(z)] for z in z_labels]
  180. list_size = (len(xs) * len(ys) * len(zs))
  181. processed_result = None
  182. state.job_count = list_size * p.n_iter
  183. def process_cell(x, y, z, ix, iy, iz):
  184. nonlocal processed_result
  185. def index(ix, iy, iz):
  186. return ix + iy * len(xs) + iz * len(xs) * len(ys)
  187. state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"
  188. processed: Processed = cell(x, y, z, ix, iy, iz)
  189. if processed_result is None:
  190. # Use our first processed result object as a template container to hold our full results
  191. processed_result = copy(processed)
  192. processed_result.images = [None] * list_size
  193. processed_result.all_prompts = [None] * list_size
  194. processed_result.all_seeds = [None] * list_size
  195. processed_result.infotexts = [None] * list_size
  196. processed_result.index_of_first_image = 1
  197. idx = index(ix, iy, iz)
  198. if processed.images:
  199. # Non-empty list indicates some degree of success.
  200. processed_result.images[idx] = processed.images[0]
  201. processed_result.all_prompts[idx] = processed.prompt
  202. processed_result.all_seeds[idx] = processed.seed
  203. processed_result.infotexts[idx] = processed.infotexts[0]
  204. else:
  205. cell_mode = "P"
  206. cell_size = (processed_result.width, processed_result.height)
  207. if processed_result.images[0] is not None:
  208. cell_mode = processed_result.images[0].mode
  209. #This corrects size in case of batches:
  210. cell_size = processed_result.images[0].size
  211. processed_result.images[idx] = Image.new(cell_mode, cell_size)
  212. if first_axes_processed == 'x':
  213. for ix, x in enumerate(xs):
  214. if second_axes_processed == 'y':
  215. for iy, y in enumerate(ys):
  216. for iz, z in enumerate(zs):
  217. process_cell(x, y, z, ix, iy, iz)
  218. else:
  219. for iz, z in enumerate(zs):
  220. for iy, y in enumerate(ys):
  221. process_cell(x, y, z, ix, iy, iz)
  222. elif first_axes_processed == 'y':
  223. for iy, y in enumerate(ys):
  224. if second_axes_processed == 'x':
  225. for ix, x in enumerate(xs):
  226. for iz, z in enumerate(zs):
  227. process_cell(x, y, z, ix, iy, iz)
  228. else:
  229. for iz, z in enumerate(zs):
  230. for ix, x in enumerate(xs):
  231. process_cell(x, y, z, ix, iy, iz)
  232. elif first_axes_processed == 'z':
  233. for iz, z in enumerate(zs):
  234. if second_axes_processed == 'x':
  235. for ix, x in enumerate(xs):
  236. for iy, y in enumerate(ys):
  237. process_cell(x, y, z, ix, iy, iz)
  238. else:
  239. for iy, y in enumerate(ys):
  240. for ix, x in enumerate(xs):
  241. process_cell(x, y, z, ix, iy, iz)
  242. if not processed_result:
  243. # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
  244. print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
  245. return Processed(p, [])
  246. elif not any(processed_result.images):
  247. print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
  248. return Processed(p, [])
  249. z_count = len(zs)
  250. for i in range(z_count):
  251. start_index = (i * len(xs) * len(ys)) + i
  252. end_index = start_index + len(xs) * len(ys)
  253. grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
  254. if draw_legend:
  255. grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
  256. processed_result.images.insert(i, grid)
  257. processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
  258. processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
  259. processed_result.infotexts.insert(i, processed_result.infotexts[start_index])
  260. sub_grid_size = processed_result.images[0].size
  261. z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
  262. if draw_legend:
  263. z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
  264. processed_result.images.insert(0, z_grid)
  265. #TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
  266. #processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
  267. #processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
  268. processed_result.infotexts.insert(0, processed_result.infotexts[0])
  269. return processed_result
  270. class SharedSettingsStackHelper(object):
  271. def __enter__(self):
  272. self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
  273. self.vae = opts.sd_vae
  274. self.uni_pc_order = opts.uni_pc_order
  275. def __exit__(self, exc_type, exc_value, tb):
  276. opts.data["sd_vae"] = self.vae
  277. opts.data["uni_pc_order"] = self.uni_pc_order
  278. modules.sd_models.reload_model_weights()
  279. modules.sd_vae.reload_vae_weights()
  280. opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
  281. re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
  282. re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
  283. re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
  284. re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
  285. class Script(scripts.Script):
  286. def title(self):
  287. return "X/Y/Z plot"
  288. def ui(self, is_img2img):
  289. self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
  290. with gr.Row():
  291. with gr.Column(scale=19):
  292. with gr.Row():
  293. x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
  294. x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
  295. x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
  296. fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
  297. with gr.Row():
  298. y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
  299. y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
  300. y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
  301. fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
  302. with gr.Row():
  303. z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
  304. z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
  305. z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
  306. fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
  307. with gr.Row(variant="compact", elem_id="axis_options"):
  308. with gr.Column():
  309. draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
  310. no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
  311. with gr.Column():
  312. include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
  313. include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
  314. with gr.Column():
  315. margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
  316. with gr.Row(variant="compact", elem_id="swap_axes"):
  317. swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
  318. swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
  319. swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
  320. def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
  321. return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown
  322. xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
  323. swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
  324. yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
  325. swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
  326. xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
  327. swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
  328. def fill(x_type):
  329. axis = self.current_axis_options[x_type]
  330. return axis.choices() if axis.choices else gr.update()
  331. fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
  332. fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
  333. fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
  334. def select_axis(axis_type,axis_values_dropdown):
  335. choices = self.current_axis_options[axis_type].choices
  336. has_choices = choices is not None
  337. current_values = axis_values_dropdown
  338. if has_choices:
  339. choices = choices()
  340. if isinstance(current_values,str):
  341. current_values = current_values.split(",")
  342. current_values = list(filter(lambda x: x in choices, current_values))
  343. return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
  344. x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
  345. y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
  346. z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
  347. def get_dropdown_update_from_params(axis,params):
  348. val_key = f"{axis} Values"
  349. vals = params.get(val_key,"")
  350. valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
  351. return gr.update(value = valslist)
  352. self.infotext_fields = (
  353. (x_type, "X Type"),
  354. (x_values, "X Values"),
  355. (x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
  356. (y_type, "Y Type"),
  357. (y_values, "Y Values"),
  358. (y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
  359. (z_type, "Z Type"),
  360. (z_values, "Z Values"),
  361. (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
  362. )
  363. return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
  364. def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
  365. if not no_fixed_seeds:
  366. modules.processing.fix_seed(p)
  367. if not opts.return_grid:
  368. p.batch_size = 1
  369. def process_axis(opt, vals, vals_dropdown):
  370. if opt.label == 'Nothing':
  371. return [0]
  372. if opt.choices is not None:
  373. valslist = vals_dropdown
  374. else:
  375. valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
  376. if opt.type == int:
  377. valslist_ext = []
  378. for val in valslist:
  379. m = re_range.fullmatch(val)
  380. mc = re_range_count.fullmatch(val)
  381. if m is not None:
  382. start = int(m.group(1))
  383. end = int(m.group(2))+1
  384. step = int(m.group(3)) if m.group(3) is not None else 1
  385. valslist_ext += list(range(start, end, step))
  386. elif mc is not None:
  387. start = int(mc.group(1))
  388. end = int(mc.group(2))
  389. num = int(mc.group(3)) if mc.group(3) is not None else 1
  390. valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
  391. else:
  392. valslist_ext.append(val)
  393. valslist = valslist_ext
  394. elif opt.type == float:
  395. valslist_ext = []
  396. for val in valslist:
  397. m = re_range_float.fullmatch(val)
  398. mc = re_range_count_float.fullmatch(val)
  399. if m is not None:
  400. start = float(m.group(1))
  401. end = float(m.group(2))
  402. step = float(m.group(3)) if m.group(3) is not None else 1
  403. valslist_ext += np.arange(start, end + step, step).tolist()
  404. elif mc is not None:
  405. start = float(mc.group(1))
  406. end = float(mc.group(2))
  407. num = int(mc.group(3)) if mc.group(3) is not None else 1
  408. valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
  409. else:
  410. valslist_ext.append(val)
  411. valslist = valslist_ext
  412. elif opt.type == str_permutations:
  413. valslist = list(permutations(valslist))
  414. valslist = [opt.type(x) for x in valslist]
  415. # Confirm options are valid before starting
  416. if opt.confirm:
  417. opt.confirm(p, valslist)
  418. return valslist
  419. x_opt = self.current_axis_options[x_type]
  420. if x_opt.choices is not None:
  421. x_values = ",".join(x_values_dropdown)
  422. xs = process_axis(x_opt, x_values, x_values_dropdown)
  423. y_opt = self.current_axis_options[y_type]
  424. if y_opt.choices is not None:
  425. y_values = ",".join(y_values_dropdown)
  426. ys = process_axis(y_opt, y_values, y_values_dropdown)
  427. z_opt = self.current_axis_options[z_type]
  428. if z_opt.choices is not None:
  429. z_values = ",".join(z_values_dropdown)
  430. zs = process_axis(z_opt, z_values, z_values_dropdown)
  431. # this could be moved to common code, but unlikely to be ever triggered anywhere else
  432. Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
  433. grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
  434. assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
  435. def fix_axis_seeds(axis_opt, axis_list):
  436. if axis_opt.label in ['Seed', 'Var. seed']:
  437. return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
  438. else:
  439. return axis_list
  440. if not no_fixed_seeds:
  441. xs = fix_axis_seeds(x_opt, xs)
  442. ys = fix_axis_seeds(y_opt, ys)
  443. zs = fix_axis_seeds(z_opt, zs)
  444. if x_opt.label == 'Steps':
  445. total_steps = sum(xs) * len(ys) * len(zs)
  446. elif y_opt.label == 'Steps':
  447. total_steps = sum(ys) * len(xs) * len(zs)
  448. elif z_opt.label == 'Steps':
  449. total_steps = sum(zs) * len(xs) * len(ys)
  450. else:
  451. total_steps = p.steps * len(xs) * len(ys) * len(zs)
  452. if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
  453. if x_opt.label == "Hires steps":
  454. total_steps += sum(xs) * len(ys) * len(zs)
  455. elif y_opt.label == "Hires steps":
  456. total_steps += sum(ys) * len(xs) * len(zs)
  457. elif z_opt.label == "Hires steps":
  458. total_steps += sum(zs) * len(xs) * len(ys)
  459. elif p.hr_second_pass_steps:
  460. total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
  461. else:
  462. total_steps *= 2
  463. total_steps *= p.n_iter
  464. image_cell_count = p.n_iter * p.batch_size
  465. cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
  466. plural_s = 's' if len(zs) > 1 else ''
  467. print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
  468. shared.total_tqdm.updateTotal(total_steps)
  469. state.xyz_plot_x = AxisInfo(x_opt, xs)
  470. state.xyz_plot_y = AxisInfo(y_opt, ys)
  471. state.xyz_plot_z = AxisInfo(z_opt, zs)
  472. # If one of the axes is very slow to change between (like SD model
  473. # checkpoint), then make sure it is in the outer iteration of the nested
  474. # `for` loop.
  475. first_axes_processed = 'z'
  476. second_axes_processed = 'y'
  477. if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
  478. first_axes_processed = 'x'
  479. if y_opt.cost > z_opt.cost:
  480. second_axes_processed = 'y'
  481. else:
  482. second_axes_processed = 'z'
  483. elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
  484. first_axes_processed = 'y'
  485. if x_opt.cost > z_opt.cost:
  486. second_axes_processed = 'x'
  487. else:
  488. second_axes_processed = 'z'
  489. elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
  490. first_axes_processed = 'z'
  491. if x_opt.cost > y_opt.cost:
  492. second_axes_processed = 'x'
  493. else:
  494. second_axes_processed = 'y'
  495. grid_infotext = [None] * (1 + len(zs))
  496. def cell(x, y, z, ix, iy, iz):
  497. if shared.state.interrupted:
  498. return Processed(p, [], p.seed, "")
  499. pc = copy(p)
  500. pc.styles = pc.styles[:]
  501. x_opt.apply(pc, x, xs)
  502. y_opt.apply(pc, y, ys)
  503. z_opt.apply(pc, z, zs)
  504. res = process_images(pc)
  505. # Sets subgrid infotexts
  506. subgrid_index = 1 + iz
  507. if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:
  508. pc.extra_generation_params = copy(pc.extra_generation_params)
  509. pc.extra_generation_params['Script'] = self.title()
  510. if x_opt.label != 'Nothing':
  511. pc.extra_generation_params["X Type"] = x_opt.label
  512. pc.extra_generation_params["X Values"] = x_values
  513. if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
  514. pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
  515. if y_opt.label != 'Nothing':
  516. pc.extra_generation_params["Y Type"] = y_opt.label
  517. pc.extra_generation_params["Y Values"] = y_values
  518. if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
  519. pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
  520. grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
  521. # Sets main grid infotext
  522. if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:
  523. pc.extra_generation_params = copy(pc.extra_generation_params)
  524. if z_opt.label != 'Nothing':
  525. pc.extra_generation_params["Z Type"] = z_opt.label
  526. pc.extra_generation_params["Z Values"] = z_values
  527. if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
  528. pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])
  529. grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
  530. return res
  531. with SharedSettingsStackHelper():
  532. processed = draw_xyz_grid(
  533. p,
  534. xs=xs,
  535. ys=ys,
  536. zs=zs,
  537. x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
  538. y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
  539. z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
  540. cell=cell,
  541. draw_legend=draw_legend,
  542. include_lone_images=include_lone_images,
  543. include_sub_grids=include_sub_grids,
  544. first_axes_processed=first_axes_processed,
  545. second_axes_processed=second_axes_processed,
  546. margin_size=margin_size
  547. )
  548. if not processed.images:
  549. # It broke, no further handling needed.
  550. return processed
  551. z_count = len(zs)
  552. # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
  553. processed.infotexts[:1+z_count] = grid_infotext[:1+z_count]
  554. if not include_lone_images:
  555. # Don't need sub-images anymore, drop from list:
  556. processed.images = processed.images[:z_count+1]
  557. if opts.grid_save:
  558. # Auto-save main and sub-grids:
  559. grid_count = z_count + 1 if z_count > 1 else 1
  560. for g in range(grid_count):
  561. #TODO: See previous comment about intentional data misalignment.
  562. adj_g = g-1 if g > 0 else g
  563. images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
  564. if not include_sub_grids:
  565. # Done with sub-grids, drop all related information:
  566. for _ in range(z_count):
  567. del processed.images[1]
  568. del processed.all_prompts[1]
  569. del processed.all_seeds[1]
  570. del processed.infotexts[1]
  571. return processed