# xyz_grid.py
import csv
import os.path
import random
import re
from collections import namedtuple
from copy import copy
from io import StringIO
from itertools import permutations, chain

import gradio as gr
import numpy as np
from PIL import Image

import modules.scripts as scripts
import modules.sd_models
import modules.sd_samplers
import modules.sd_vae
import modules.shared as shared
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_schedulers, errors
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
from modules.ui_components import ToolButton, InputAccordion
  21. fill_values_symbol = "\U0001f4d2" # 📒
  22. AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
  23. def apply_field(field):
  24. def fun(p, x, xs):
  25. setattr(p, field, x)
  26. return fun
  27. def apply_prompt(p, x, xs):
  28. if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
  29. raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
  30. p.prompt = p.prompt.replace(xs[0], x)
  31. p.negative_prompt = p.negative_prompt.replace(xs[0], x)
  32. def apply_order(p, x, xs):
  33. token_order = []
  34. # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
  35. for token in x:
  36. token_order.append((p.prompt.find(token), token))
  37. token_order.sort(key=lambda t: t[0])
  38. prompt_parts = []
  39. # Split the prompt up, taking out the tokens
  40. for _, token in token_order:
  41. n = p.prompt.find(token)
  42. prompt_parts.append(p.prompt[0:n])
  43. p.prompt = p.prompt[n + len(token):]
  44. # Rebuild the prompt with the tokens in the order we want
  45. prompt_tmp = ""
  46. for idx, part in enumerate(prompt_parts):
  47. prompt_tmp += part
  48. prompt_tmp += x[idx]
  49. p.prompt = prompt_tmp + p.prompt
  50. def confirm_samplers(p, xs):
  51. for x in xs:
  52. if x.lower() not in sd_samplers.samplers_map:
  53. raise RuntimeError(f"Unknown sampler: {x}")
  54. def apply_checkpoint(p, x, xs):
  55. info = modules.sd_models.get_closet_checkpoint_match(x)
  56. if info is None:
  57. raise RuntimeError(f"Unknown checkpoint: {x}")
  58. p.override_settings['sd_model_checkpoint'] = info.name
  59. def confirm_checkpoints(p, xs):
  60. for x in xs:
  61. if modules.sd_models.get_closet_checkpoint_match(x) is None:
  62. raise RuntimeError(f"Unknown checkpoint: {x}")
  63. def confirm_checkpoints_or_none(p, xs):
  64. for x in xs:
  65. if x in (None, "", "None", "none"):
  66. continue
  67. if modules.sd_models.get_closet_checkpoint_match(x) is None:
  68. raise RuntimeError(f"Unknown checkpoint: {x}")
  69. def confirm_range(min_val, max_val, axis_label):
  70. """Generates a AxisOption.confirm() function that checks all values are within the specified range."""
  71. def confirm_range_fun(p, xs):
  72. for x in xs:
  73. if not (max_val >= x >= min_val):
  74. raise ValueError(f'{axis_label} value "{x}" out of range [{min_val}, {max_val}]')
  75. return confirm_range_fun
  76. def apply_size(p, x: str, xs) -> None:
  77. try:
  78. width, _, height = x.partition('x')
  79. width = int(width.strip())
  80. height = int(height.strip())
  81. p.width = width
  82. p.height = height
  83. except ValueError:
  84. print(f"Invalid size in XYZ plot: {x}")
  85. def find_vae(name: str):
  86. if (name := name.strip().lower()) in ('auto', 'automatic'):
  87. return 'Automatic'
  88. elif name == 'none':
  89. return 'None'
  90. return next((k for k in modules.sd_vae.vae_dict if k.lower() == name), print(f'No VAE found for {name}; using Automatic') or 'Automatic')
  91. def apply_vae(p, x, xs):
  92. p.override_settings['sd_vae'] = find_vae(x)
  93. def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
  94. p.styles.extend(x.split(','))
  95. def apply_uni_pc_order(p, x, xs):
  96. p.override_settings['uni_pc_order'] = min(x, p.steps - 1)
  97. def apply_face_restore(p, opt, x):
  98. opt = opt.lower()
  99. if opt == 'codeformer':
  100. is_active = True
  101. p.face_restoration_model = 'CodeFormer'
  102. elif opt == 'gfpgan':
  103. is_active = True
  104. p.face_restoration_model = 'GFPGAN'
  105. else:
  106. is_active = opt in ('true', 'yes', 'y', '1')
  107. p.restore_faces = is_active
  108. def apply_override(field, boolean: bool = False):
  109. def fun(p, x, xs):
  110. if boolean:
  111. x = True if x.lower() == "true" else False
  112. p.override_settings[field] = x
  113. return fun
  114. def boolean_choice(reverse: bool = False):
  115. def choice():
  116. return ["False", "True"] if reverse else ["True", "False"]
  117. return choice
  118. def format_value_add_label(p, opt, x):
  119. if type(x) == float:
  120. x = round(x, 8)
  121. return f"{opt.label}: {x}"
  122. def format_value(p, opt, x):
  123. if type(x) == float:
  124. x = round(x, 8)
  125. return x
  126. def format_value_join_list(p, opt, x):
  127. return ", ".join(x)
  128. def do_nothing(p, x, xs):
  129. pass
  130. def format_nothing(p, opt, x):
  131. return ""
  132. def format_remove_path(p, opt, x):
  133. return os.path.basename(x)
  134. def str_permutations(x):
  135. """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
  136. return x
  137. def list_to_csv_string(data_list):
  138. with StringIO() as o:
  139. csv.writer(o).writerow(data_list)
  140. return o.getvalue().strip()
  141. def csv_string_to_list_strip(data_str):
  142. return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str), skipinitialspace=True))))
  143. class AxisOption:
  144. def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):
  145. self.label = label
  146. self.type = type
  147. self.apply = apply
  148. self.format_value = format_value
  149. self.confirm = confirm
  150. self.cost = cost
  151. self.prepare = prepare
  152. self.choices = choices
  153. class AxisOptionImg2Img(AxisOption):
  154. def __init__(self, *args, **kwargs):
  155. super().__init__(*args, **kwargs)
  156. self.is_img2img = True
  157. class AxisOptionTxt2Img(AxisOption):
  158. def __init__(self, *args, **kwargs):
  159. super().__init__(*args, **kwargs)
  160. self.is_img2img = False
  161. axis_options = [
  162. AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
  163. AxisOption("Seed", int, apply_field("seed")),
  164. AxisOption("Var. seed", int, apply_field("subseed")),
  165. AxisOption("Var. strength", float, apply_field("subseed_strength")),
  166. AxisOption("Steps", int, apply_field("steps")),
  167. AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
  168. AxisOption("CFG Scale", float, apply_field("cfg_scale")),
  169. AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
  170. AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
  171. AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
  172. AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers if x.name not in opts.hide_samplers]),
  173. AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
  174. AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
  175. AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
  176. AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
  177. AxisOption("Sigma Churn", float, apply_field("s_churn")),
  178. AxisOption("Sigma min", float, apply_field("s_tmin")),
  179. AxisOption("Sigma max", float, apply_field("s_tmax")),
  180. AxisOption("Sigma noise", float, apply_field("s_noise")),
  181. AxisOption("Schedule type", str, apply_field("scheduler"), choices=lambda: [x.label for x in sd_schedulers.schedulers]),
  182. AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
  183. AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
  184. AxisOption("Schedule rho", float, apply_override("rho")),
  185. AxisOption("Skip Early CFG", float, apply_override('skip_early_cond')),
  186. AxisOption("Beta schedule alpha", float, apply_override("beta_dist_alpha")),
  187. AxisOption("Beta schedule beta", float, apply_override("beta_dist_beta")),
  188. AxisOption("Eta", float, apply_field("eta")),
  189. AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
  190. AxisOption("Denoising", float, apply_field("denoising_strength")),
  191. AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
  192. AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
  193. AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
  194. AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
  195. AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),
  196. AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
  197. AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
  198. AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
  199. AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
  200. AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
  201. AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
  202. AxisOption("SGM noise multiplier", str, apply_override('sgm_noise_multiplier', boolean=True), choices=boolean_choice(reverse=True)),
  203. AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
  204. AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
  205. AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
  206. AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
  207. AxisOption("Size", str, apply_size),
  208. ]
  209. def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size, draw_grid):
  210. hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
  211. ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
  212. title_texts = [[images.GridAnnotation(z)] for z in z_labels]
  213. list_size = (len(xs) * len(ys) * len(zs))
  214. processed_result = None
  215. state.job_count = list_size * p.n_iter
  216. def process_cell(x, y, z, ix, iy, iz):
  217. nonlocal processed_result
  218. def index(ix, iy, iz):
  219. return ix + iy * len(xs) + iz * len(xs) * len(ys)
  220. state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"
  221. processed: Processed = cell(x, y, z, ix, iy, iz)
  222. if processed_result is None:
  223. # Use our first processed result object as a template container to hold our full results
  224. processed_result = copy(processed)
  225. processed_result.images = [None] * list_size
  226. processed_result.all_prompts = [None] * list_size
  227. processed_result.all_seeds = [None] * list_size
  228. processed_result.infotexts = [None] * list_size
  229. processed_result.index_of_first_image = 1
  230. idx = index(ix, iy, iz)
  231. if processed.images:
  232. # Non-empty list indicates some degree of success.
  233. processed_result.images[idx] = processed.images[0]
  234. processed_result.all_prompts[idx] = processed.prompt
  235. processed_result.all_seeds[idx] = processed.seed
  236. processed_result.infotexts[idx] = processed.infotexts[0]
  237. else:
  238. cell_mode = "P"
  239. cell_size = (processed_result.width, processed_result.height)
  240. if processed_result.images[0] is not None:
  241. cell_mode = processed_result.images[0].mode
  242. # This corrects size in case of batches:
  243. cell_size = processed_result.images[0].size
  244. processed_result.images[idx] = Image.new(cell_mode, cell_size)
  245. if first_axes_processed == 'x':
  246. for ix, x in enumerate(xs):
  247. if second_axes_processed == 'y':
  248. for iy, y in enumerate(ys):
  249. for iz, z in enumerate(zs):
  250. process_cell(x, y, z, ix, iy, iz)
  251. else:
  252. for iz, z in enumerate(zs):
  253. for iy, y in enumerate(ys):
  254. process_cell(x, y, z, ix, iy, iz)
  255. elif first_axes_processed == 'y':
  256. for iy, y in enumerate(ys):
  257. if second_axes_processed == 'x':
  258. for ix, x in enumerate(xs):
  259. for iz, z in enumerate(zs):
  260. process_cell(x, y, z, ix, iy, iz)
  261. else:
  262. for iz, z in enumerate(zs):
  263. for ix, x in enumerate(xs):
  264. process_cell(x, y, z, ix, iy, iz)
  265. elif first_axes_processed == 'z':
  266. for iz, z in enumerate(zs):
  267. if second_axes_processed == 'x':
  268. for ix, x in enumerate(xs):
  269. for iy, y in enumerate(ys):
  270. process_cell(x, y, z, ix, iy, iz)
  271. else:
  272. for iy, y in enumerate(ys):
  273. for ix, x in enumerate(xs):
  274. process_cell(x, y, z, ix, iy, iz)
  275. if not processed_result:
  276. # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
  277. print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
  278. return Processed(p, [])
  279. elif not any(processed_result.images):
  280. print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
  281. return Processed(p, [])
  282. if draw_grid:
  283. z_count = len(zs)
  284. for i in range(z_count):
  285. start_index = (i * len(xs) * len(ys)) + i
  286. end_index = start_index + len(xs) * len(ys)
  287. grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
  288. if draw_legend:
  289. grid_max_w, grid_max_h = map(max, zip(*(img.size for img in processed_result.images[start_index:end_index])))
  290. grid = images.draw_grid_annotations(grid, grid_max_w, grid_max_h, hor_texts, ver_texts, margin_size)
  291. processed_result.images.insert(i, grid)
  292. processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
  293. processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
  294. processed_result.infotexts.insert(i, processed_result.infotexts[start_index])
  295. z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
  296. z_sub_grid_max_w, z_sub_grid_max_h = map(max, zip(*(img.size for img in processed_result.images[:z_count])))
  297. if draw_legend:
  298. z_grid = images.draw_grid_annotations(z_grid, z_sub_grid_max_w, z_sub_grid_max_h, title_texts, [[images.GridAnnotation()]])
  299. processed_result.images.insert(0, z_grid)
  300. # TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
  301. # processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
  302. # processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
  303. processed_result.infotexts.insert(0, processed_result.infotexts[0])
  304. return processed_result
  305. class SharedSettingsStackHelper(object):
  306. def __enter__(self):
  307. pass
  308. def __exit__(self, exc_type, exc_value, tb):
  309. modules.sd_models.reload_model_weights()
  310. modules.sd_vae.reload_vae_weights()
  311. re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
  312. re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
  313. re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*")
  314. re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*])?\s*")
  315. class Script(scripts.Script):
  316. def title(self):
  317. return "X/Y/Z plot"
  318. def ui(self, is_img2img):
  319. self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
  320. with gr.Row():
  321. with gr.Column(scale=19):
  322. with gr.Row():
  323. x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
  324. x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
  325. x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
  326. fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
  327. with gr.Row():
  328. y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
  329. y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
  330. y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
  331. fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
  332. with gr.Row():
  333. z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
  334. z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
  335. z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
  336. fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
  337. with gr.Row(variant="compact", elem_id="axis_options"):
  338. with gr.Column():
  339. no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
  340. with gr.Row():
  341. vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.")
  342. vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.")
  343. vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.")
  344. with gr.Column():
  345. include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
  346. csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
  347. with InputAccordion(True, label='Draw grid', elem_id=self.elem_id('draw_grid')) as draw_grid:
  348. with gr.Row():
  349. include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
  350. draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
  351. margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
  352. with gr.Row(variant="compact", elem_id="swap_axes"):
  353. swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
  354. swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
  355. swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
  356. def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
  357. return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown
  358. xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
  359. swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
  360. yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
  361. swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
  362. xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
  363. swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
  364. def fill(axis_type, csv_mode):
  365. axis = self.current_axis_options[axis_type]
  366. if axis.choices:
  367. if csv_mode:
  368. return list_to_csv_string(axis.choices()), gr.update()
  369. else:
  370. return gr.update(), axis.choices()
  371. else:
  372. return gr.update(), gr.update()
  373. fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])
  374. fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])
  375. fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])
  376. def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
  377. axis_type = axis_type or 0 # if axle type is None set to 0
  378. choices = self.current_axis_options[axis_type].choices
  379. has_choices = choices is not None
  380. if has_choices:
  381. choices = choices()
  382. if csv_mode:
  383. if axis_values_dropdown:
  384. axis_values = list_to_csv_string(list(filter(lambda x: x in choices, axis_values_dropdown)))
  385. axis_values_dropdown = []
  386. else:
  387. if axis_values:
  388. axis_values_dropdown = list(filter(lambda x: x in choices, csv_string_to_list_strip(axis_values)))
  389. axis_values = ""
  390. return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=axis_values),
  391. gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=axis_values_dropdown))
  392. x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
  393. y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
  394. z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])
  395. def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):
  396. _fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)
  397. _fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)
  398. _fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)
  399. return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown
  400. csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])
  401. def get_dropdown_update_from_params(axis, params):
  402. val_key = f"{axis} Values"
  403. vals = params.get(val_key, "")
  404. valslist = csv_string_to_list_strip(vals)
  405. return gr.update(value=valslist)
  406. self.infotext_fields = (
  407. (x_type, "X Type"),
  408. (x_values, "X Values"),
  409. (x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
  410. (y_type, "Y Type"),
  411. (y_values, "Y Values"),
  412. (y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
  413. (z_type, "Z Type"),
  414. (z_values, "Z Values"),
  415. (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
  416. )
  417. return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode, draw_grid]
  418. def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode, draw_grid):
  419. x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0
  420. if not no_fixed_seeds:
  421. modules.processing.fix_seed(p)
  422. if not opts.return_grid:
  423. p.batch_size = 1
  424. def process_axis(opt, vals, vals_dropdown):
  425. if opt.label == 'Nothing':
  426. return [0]
  427. if opt.choices is not None and not csv_mode:
  428. valslist = vals_dropdown
  429. elif opt.prepare is not None:
  430. valslist = opt.prepare(vals)
  431. else:
  432. valslist = csv_string_to_list_strip(vals)
  433. if opt.type == int:
  434. valslist_ext = []
  435. for val in valslist:
  436. if val.strip() == '':
  437. continue
  438. m = re_range.fullmatch(val)
  439. mc = re_range_count.fullmatch(val)
  440. if m is not None:
  441. start = int(m.group(1))
  442. end = int(m.group(2)) + 1
  443. step = int(m.group(3)) if m.group(3) is not None else 1
  444. valslist_ext += list(range(start, end, step))
  445. elif mc is not None:
  446. start = int(mc.group(1))
  447. end = int(mc.group(2))
  448. num = int(mc.group(3)) if mc.group(3) is not None else 1
  449. valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
  450. else:
  451. valslist_ext.append(val)
  452. valslist = valslist_ext
  453. elif opt.type == float:
  454. valslist_ext = []
  455. for val in valslist:
  456. if val.strip() == '':
  457. continue
  458. m = re_range_float.fullmatch(val)
  459. mc = re_range_count_float.fullmatch(val)
  460. if m is not None:
  461. start = float(m.group(1))
  462. end = float(m.group(2))
  463. step = float(m.group(3)) if m.group(3) is not None else 1
  464. valslist_ext += np.arange(start, end + step, step).tolist()
  465. elif mc is not None:
  466. start = float(mc.group(1))
  467. end = float(mc.group(2))
  468. num = int(mc.group(3)) if mc.group(3) is not None else 1
  469. valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
  470. else:
  471. valslist_ext.append(val)
  472. valslist = valslist_ext
  473. elif opt.type == str_permutations:
  474. valslist = list(permutations(valslist))
  475. valslist = [opt.type(x) for x in valslist]
  476. # Confirm options are valid before starting
  477. if opt.confirm:
  478. opt.confirm(p, valslist)
  479. return valslist
  480. x_opt = self.current_axis_options[x_type]
  481. if x_opt.choices is not None and not csv_mode:
  482. x_values = list_to_csv_string(x_values_dropdown)
  483. xs = process_axis(x_opt, x_values, x_values_dropdown)
  484. y_opt = self.current_axis_options[y_type]
  485. if y_opt.choices is not None and not csv_mode:
  486. y_values = list_to_csv_string(y_values_dropdown)
  487. ys = process_axis(y_opt, y_values, y_values_dropdown)
  488. z_opt = self.current_axis_options[z_type]
  489. if z_opt.choices is not None and not csv_mode:
  490. z_values = list_to_csv_string(z_values_dropdown)
  491. zs = process_axis(z_opt, z_values, z_values_dropdown)
  492. # this could be moved to common code, but unlikely to be ever triggered anywhere else
  493. Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
  494. grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
  495. assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
  496. def fix_axis_seeds(axis_opt, axis_list):
  497. if axis_opt.label in ['Seed', 'Var. seed']:
  498. return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
  499. else:
  500. return axis_list
  501. if not no_fixed_seeds:
  502. xs = fix_axis_seeds(x_opt, xs)
  503. ys = fix_axis_seeds(y_opt, ys)
  504. zs = fix_axis_seeds(z_opt, zs)
  505. if x_opt.label == 'Steps':
  506. total_steps = sum(xs) * len(ys) * len(zs)
  507. elif y_opt.label == 'Steps':
  508. total_steps = sum(ys) * len(xs) * len(zs)
  509. elif z_opt.label == 'Steps':
  510. total_steps = sum(zs) * len(xs) * len(ys)
  511. else:
  512. total_steps = p.steps * len(xs) * len(ys) * len(zs)
  513. if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
  514. if x_opt.label == "Hires steps":
  515. total_steps += sum(xs) * len(ys) * len(zs)
  516. elif y_opt.label == "Hires steps":
  517. total_steps += sum(ys) * len(xs) * len(zs)
  518. elif z_opt.label == "Hires steps":
  519. total_steps += sum(zs) * len(xs) * len(ys)
  520. elif p.hr_second_pass_steps:
  521. total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
  522. else:
  523. total_steps *= 2
  524. total_steps *= p.n_iter
  525. image_cell_count = p.n_iter * p.batch_size
  526. cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
  527. plural_s = 's' if len(zs) > 1 else ''
  528. print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
  529. shared.total_tqdm.updateTotal(total_steps)
  530. state.xyz_plot_x = AxisInfo(x_opt, xs)
  531. state.xyz_plot_y = AxisInfo(y_opt, ys)
  532. state.xyz_plot_z = AxisInfo(z_opt, zs)
  # If one of the axes is very slow to change between (like SD model
  # checkpoint), then make sure it is in the outer iteration of the nested
  # `for` loop: the costliest axis is processed first, the next costliest
  # second.
  #
  # sorted() is stable (also with reverse=True), so axes with equal cost keep
  # the z, y, x priority of the candidate tuple. That yields the same choice
  # as the former if/elif chain whenever a single axis is strictly the most
  # expensive, and additionally handles a tie for the highest cost, which
  # previously fell through to the 'z'/'y' defaults and could leave an
  # expensive axis in the innermost loop.
  axis_costs = {'x': x_opt.cost, 'y': y_opt.cost, 'z': z_opt.cost}
  first_axes_processed, second_axes_processed, _ = sorted(
      ('z', 'y', 'x'),
      key=axis_costs.__getitem__,
      reverse=True,
  )
  556. grid_infotext = [None] * (1 + len(zs))
  def cell(x, y, z, ix, iy, iz):
      """Process one grid cell and capture infotext for the first cell of each grid.

      x, y, z are the axis values handed to each axis option's apply();
      ix, iy, iz are their indices, used for the optional per-axis seed
      offsets and to pick the grid_infotext slot (0 is the main grid,
      1 + iz is the sub-grid for that z index).

      Returns the result of process_images() for the per-cell copy of p, or
      an image-less Processed built from p when generation was interrupted or
      stopped, or when process_images() raised.
      """
      if shared.state.interrupted or state.stopping_generation:
          return Processed(p, [], p.seed, "")

      # Axis options are applied to a shallow copy of p; styles is re-sliced
      # so the copy carries its own list object.
      pc = copy(p)
      pc.styles = pc.styles[:]
      for axis_opt, axis_value, axis_values in ((x_opt, x, xs), (y_opt, y, ys), (z_opt, z, zs)):
          axis_opt.apply(pc, axis_value, axis_values)

      # Seed offsets: an axis only contributes a stride when it varies seeds.
      x_span = len(xs) if vary_seeds_x else 1
      y_span = len(ys) if vary_seeds_y else 1
      seed_offsets = (
          (vary_seeds_x, ix),
          (vary_seeds_y, iy * x_span),
          (vary_seeds_z, iz * x_span * y_span),
      )
      for enabled, offset in seed_offsets:
          if enabled:
              pc.seed += offset

      try:
          res = process_images(pc)
      except Exception as e:
          errors.display(e, "generating image for xyz plot")
          res = Processed(p, [], p.seed, "")

      seed_labels = ("Seed", "Var. seed")
      first_in_plane = ix == 0 and iy == 0

      # Sets subgrid infotexts
      subgrid_index = 1 + iz
      if first_in_plane and grid_infotext[subgrid_index] is None:
          pc.extra_generation_params = copy(pc.extra_generation_params)
          params = pc.extra_generation_params
          params['Script'] = self.title()

          if x_opt.label != 'Nothing':
              params["X Type"] = x_opt.label
              params["X Values"] = x_values
              if not no_fixed_seeds and x_opt.label in seed_labels:
                  params["Fixed X Values"] = ", ".join(map(str, xs))

          if y_opt.label != 'Nothing':
              params["Y Type"] = y_opt.label
              params["Y Values"] = y_values
              if not no_fixed_seeds and y_opt.label in seed_labels:
                  params["Fixed Y Values"] = ", ".join(map(str, ys))

          grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

      # Sets main grid infotext (copies the params again, so any X/Y entries
      # added just above are carried over before the Z entries are appended)
      if first_in_plane and iz == 0 and grid_infotext[0] is None:
          pc.extra_generation_params = copy(pc.extra_generation_params)
          params = pc.extra_generation_params

          if z_opt.label != 'Nothing':
              params["Z Type"] = z_opt.label
              params["Z Values"] = z_values
              if not no_fixed_seeds and z_opt.label in seed_labels:
                  params["Fixed Z Values"] = ", ".join(map(str, zs))

          grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

      return res
  604. with SharedSettingsStackHelper():
  605. processed = draw_xyz_grid(
  606. p,
  607. xs=xs,
  608. ys=ys,
  609. zs=zs,
  610. x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
  611. y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
  612. z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
  613. cell=cell,
  614. draw_legend=draw_legend,
  615. include_lone_images=include_lone_images,
  616. include_sub_grids=include_sub_grids,
  617. first_axes_processed=first_axes_processed,
  618. second_axes_processed=second_axes_processed,
  619. margin_size=margin_size,
  620. draw_grid=draw_grid,
  621. )
  622. if not processed.images:
  623. # It broke, no further handling needed.
  624. return processed
  625. z_count = len(zs)
  626. if draw_grid:
  627. # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
  628. processed.infotexts[:1 + z_count] = grid_infotext[:1 + z_count]
  629. if not include_lone_images:
  630. # Don't need sub-images anymore, drop from list:
  631. processed.images = processed.images[:z_count + 1] if draw_grid else []
  632. if draw_grid and opts.grid_save:
  633. # Auto-save main and sub-grids:
  634. grid_count = z_count + 1 if z_count > 1 else 1
  635. for g in range(grid_count):
  636. # TODO: See previous comment about intentional data misalignment.
  637. adj_g = g - 1 if g > 0 else g
  638. images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
  639. if not include_sub_grids: # if not include_sub_grids then skip saving after the first grid
  640. break
  641. if draw_grid and not include_sub_grids:
  642. # Done with sub-grids, drop all related information:
  643. for _ in range(z_count):
  644. del processed.images[1]
  645. del processed.all_prompts[1]
  646. del processed.all_seeds[1]
  647. del processed.infotexts[1]
  648. return processed