xy_grid.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. from collections import namedtuple
  2. from copy import copy
  3. from itertools import permutations, chain
  4. import random
  5. import csv
  6. from io import StringIO
  7. from PIL import Image
  8. import numpy as np
  9. import modules.scripts as scripts
  10. import gradio as gr
  11. from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
  12. from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
  13. from modules.shared import opts, cmd_opts, state
  14. import modules.shared as shared
  15. import modules.sd_samplers
  16. import modules.sd_models
  17. import modules.sd_vae
  18. import glob
  19. import os
  20. import re
  21. from modules.ui_components import ToolButton
  22. fill_values_symbol = "\U0001f4d2" # 📒
  23. def apply_field(field):
  24. def fun(p, x, xs):
  25. setattr(p, field, x)
  26. return fun
  27. def apply_prompt(p, x, xs):
  28. if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
  29. raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
  30. p.prompt = p.prompt.replace(xs[0], x)
  31. p.negative_prompt = p.negative_prompt.replace(xs[0], x)
  32. def apply_order(p, x, xs):
  33. token_order = []
  34. # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
  35. for token in x:
  36. token_order.append((p.prompt.find(token), token))
  37. token_order.sort(key=lambda t: t[0])
  38. prompt_parts = []
  39. # Split the prompt up, taking out the tokens
  40. for _, token in token_order:
  41. n = p.prompt.find(token)
  42. prompt_parts.append(p.prompt[0:n])
  43. p.prompt = p.prompt[n + len(token):]
  44. # Rebuild the prompt with the tokens in the order we want
  45. prompt_tmp = ""
  46. for idx, part in enumerate(prompt_parts):
  47. prompt_tmp += part
  48. prompt_tmp += x[idx]
  49. p.prompt = prompt_tmp + p.prompt
  50. def apply_sampler(p, x, xs):
  51. sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
  52. if sampler_name is None:
  53. raise RuntimeError(f"Unknown sampler: {x}")
  54. p.sampler_name = sampler_name
  55. def confirm_samplers(p, xs):
  56. for x in xs:
  57. if x.lower() not in sd_samplers.samplers_map:
  58. raise RuntimeError(f"Unknown sampler: {x}")
  59. def apply_checkpoint(p, x, xs):
  60. info = modules.sd_models.get_closet_checkpoint_match(x)
  61. if info is None:
  62. raise RuntimeError(f"Unknown checkpoint: {x}")
  63. modules.sd_models.reload_model_weights(shared.sd_model, info)
  64. def confirm_checkpoints(p, xs):
  65. for x in xs:
  66. if modules.sd_models.get_closet_checkpoint_match(x) is None:
  67. raise RuntimeError(f"Unknown checkpoint: {x}")
  68. def apply_clip_skip(p, x, xs):
  69. opts.data["CLIP_stop_at_last_layers"] = x
  70. def apply_upscale_latent_space(p, x, xs):
  71. if x.lower().strip() != '0':
  72. opts.data["use_scale_latent_for_hires_fix"] = True
  73. else:
  74. opts.data["use_scale_latent_for_hires_fix"] = False
  75. def find_vae(name: str):
  76. if name.lower() in ['auto', 'automatic']:
  77. return modules.sd_vae.unspecified
  78. if name.lower() == 'none':
  79. return None
  80. else:
  81. choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
  82. if len(choices) == 0:
  83. print(f"No VAE found for {name}; using automatic")
  84. return modules.sd_vae.unspecified
  85. else:
  86. return modules.sd_vae.vae_dict[choices[0]]
  87. def apply_vae(p, x, xs):
  88. modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
  89. def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
  90. p.styles = x.split(',')
  91. def format_value_add_label(p, opt, x):
  92. if type(x) == float:
  93. x = round(x, 8)
  94. return f"{opt.label}: {x}"
  95. def format_value(p, opt, x):
  96. if type(x) == float:
  97. x = round(x, 8)
  98. return x
  99. def format_value_join_list(p, opt, x):
  100. return ", ".join(x)
  101. def do_nothing(p, x, xs):
  102. pass
  103. def format_nothing(p, opt, x):
  104. return ""
  105. def str_permutations(x):
  106. """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
  107. return x
  108. class AxisOption:
  109. def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
  110. self.label = label
  111. self.type = type
  112. self.apply = apply
  113. self.format_value = format_value
  114. self.confirm = confirm
  115. self.cost = cost
  116. self.choices = choices
  117. self.is_img2img = False
  118. class AxisOptionImg2Img(AxisOption):
  119. def __init__(self, *args, **kwargs):
  120. super().__init__(*args, **kwargs)
  121. self.is_img2img = False
  122. axis_options = [
  123. AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
  124. AxisOption("Seed", int, apply_field("seed")),
  125. AxisOption("Var. seed", int, apply_field("subseed")),
  126. AxisOption("Var. strength", float, apply_field("subseed_strength")),
  127. AxisOption("Steps", int, apply_field("steps")),
  128. AxisOption("CFG Scale", float, apply_field("cfg_scale")),
  129. AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
  130. AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
  131. AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
  132. AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
  133. AxisOption("Sigma Churn", float, apply_field("s_churn")),
  134. AxisOption("Sigma min", float, apply_field("s_tmin")),
  135. AxisOption("Sigma max", float, apply_field("s_tmax")),
  136. AxisOption("Sigma noise", float, apply_field("s_noise")),
  137. AxisOption("Eta", float, apply_field("eta")),
  138. AxisOption("Clip skip", int, apply_clip_skip),
  139. AxisOption("Denoising", float, apply_field("denoising_strength")),
  140. AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]),
  141. AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
  142. AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
  143. AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
  144. ]
  145. def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
  146. ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
  147. hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
  148. # Temporary list of all the images that are generated to be populated into the grid.
  149. # Will be filled with empty images for any individual step that fails to process properly
  150. image_cache = [None] * (len(xs) * len(ys))
  151. processed_result = None
  152. cell_mode = "P"
  153. cell_size = (1, 1)
  154. state.job_count = len(xs) * len(ys) * p.n_iter
  155. def process_cell(x, y, ix, iy):
  156. nonlocal image_cache, processed_result, cell_mode, cell_size
  157. state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
  158. processed: Processed = cell(x, y)
  159. try:
  160. # this dereference will throw an exception if the image was not processed
  161. # (this happens in cases such as if the user stops the process from the UI)
  162. processed_image = processed.images[0]
  163. if processed_result is None:
  164. # Use our first valid processed result as a template container to hold our full results
  165. processed_result = copy(processed)
  166. cell_mode = processed_image.mode
  167. cell_size = processed_image.size
  168. processed_result.images = [Image.new(cell_mode, cell_size)]
  169. image_cache[ix + iy * len(xs)] = processed_image
  170. if include_lone_images:
  171. processed_result.images.append(processed_image)
  172. processed_result.all_prompts.append(processed.prompt)
  173. processed_result.all_seeds.append(processed.seed)
  174. processed_result.infotexts.append(processed.infotexts[0])
  175. except:
  176. image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
  177. if swap_axes_processing_order:
  178. for ix, x in enumerate(xs):
  179. for iy, y in enumerate(ys):
  180. process_cell(x, y, ix, iy)
  181. else:
  182. for iy, y in enumerate(ys):
  183. for ix, x in enumerate(xs):
  184. process_cell(x, y, ix, iy)
  185. if not processed_result:
  186. print("Unexpected error: draw_xy_grid failed to return even a single processed image")
  187. return Processed(p, [])
  188. grid = images.image_grid(image_cache, rows=len(ys))
  189. if draw_legend:
  190. grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
  191. processed_result.images[0] = grid
  192. return processed_result
  193. class SharedSettingsStackHelper(object):
  194. def __enter__(self):
  195. self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
  196. self.vae = opts.sd_vae
  197. def __exit__(self, exc_type, exc_value, tb):
  198. opts.data["sd_vae"] = self.vae
  199. modules.sd_models.reload_model_weights()
  200. modules.sd_vae.reload_vae_weights()
  201. opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
  202. re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
  203. re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
  204. re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
  205. re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
  206. class Script(scripts.Script):
  207. def title(self):
  208. return "X/Y plot"
  209. def ui(self, is_img2img):
  210. current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img]
  211. with gr.Row():
  212. with gr.Column(scale=19):
  213. with gr.Row():
  214. x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
  215. x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
  216. fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
  217. with gr.Row():
  218. y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
  219. y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
  220. fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
  221. with gr.Row(variant="compact"):
  222. draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
  223. include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
  224. no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
  225. swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
  226. def swap_axes(x_type, x_values, y_type, y_values):
  227. nonlocal current_axis_options
  228. return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values
  229. swap_args = [x_type, x_values, y_type, y_values]
  230. swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
  231. def fill(x_type):
  232. axis = axis_options[x_type]
  233. return ", ".join(axis.choices()) if axis.choices else gr.update()
  234. fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
  235. fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
  236. def select_axis(x_type):
  237. return gr.Button.update(visible=axis_options[x_type].choices is not None)
  238. x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
  239. y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
  240. return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
  241. def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
  242. if not no_fixed_seeds:
  243. modules.processing.fix_seed(p)
  244. if not opts.return_grid:
  245. p.batch_size = 1
  246. def process_axis(opt, vals):
  247. if opt.label == 'Nothing':
  248. return [0]
  249. valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
  250. if opt.type == int:
  251. valslist_ext = []
  252. for val in valslist:
  253. m = re_range.fullmatch(val)
  254. mc = re_range_count.fullmatch(val)
  255. if m is not None:
  256. start = int(m.group(1))
  257. end = int(m.group(2))+1
  258. step = int(m.group(3)) if m.group(3) is not None else 1
  259. valslist_ext += list(range(start, end, step))
  260. elif mc is not None:
  261. start = int(mc.group(1))
  262. end = int(mc.group(2))
  263. num = int(mc.group(3)) if mc.group(3) is not None else 1
  264. valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
  265. else:
  266. valslist_ext.append(val)
  267. valslist = valslist_ext
  268. elif opt.type == float:
  269. valslist_ext = []
  270. for val in valslist:
  271. m = re_range_float.fullmatch(val)
  272. mc = re_range_count_float.fullmatch(val)
  273. if m is not None:
  274. start = float(m.group(1))
  275. end = float(m.group(2))
  276. step = float(m.group(3)) if m.group(3) is not None else 1
  277. valslist_ext += np.arange(start, end + step, step).tolist()
  278. elif mc is not None:
  279. start = float(mc.group(1))
  280. end = float(mc.group(2))
  281. num = int(mc.group(3)) if mc.group(3) is not None else 1
  282. valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
  283. else:
  284. valslist_ext.append(val)
  285. valslist = valslist_ext
  286. elif opt.type == str_permutations:
  287. valslist = list(permutations(valslist))
  288. valslist = [opt.type(x) for x in valslist]
  289. # Confirm options are valid before starting
  290. if opt.confirm:
  291. opt.confirm(p, valslist)
  292. return valslist
  293. x_opt = axis_options[x_type]
  294. xs = process_axis(x_opt, x_values)
  295. y_opt = axis_options[y_type]
  296. ys = process_axis(y_opt, y_values)
  297. def fix_axis_seeds(axis_opt, axis_list):
  298. if axis_opt.label in ['Seed', 'Var. seed']:
  299. return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
  300. else:
  301. return axis_list
  302. if not no_fixed_seeds:
  303. xs = fix_axis_seeds(x_opt, xs)
  304. ys = fix_axis_seeds(y_opt, ys)
  305. if x_opt.label == 'Steps':
  306. total_steps = sum(xs) * len(ys)
  307. elif y_opt.label == 'Steps':
  308. total_steps = sum(ys) * len(xs)
  309. else:
  310. total_steps = p.steps * len(xs) * len(ys)
  311. if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
  312. total_steps *= 2
  313. print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
  314. shared.total_tqdm.updateTotal(total_steps * p.n_iter)
  315. grid_infotext = [None]
  316. # If one of the axes is very slow to change between (like SD model
  317. # checkpoint), then make sure it is in the outer iteration of the nested
  318. # `for` loop.
  319. swap_axes_processing_order = x_opt.cost > y_opt.cost
  320. def cell(x, y):
  321. if shared.state.interrupted:
  322. return Processed(p, [], p.seed, "")
  323. pc = copy(p)
  324. x_opt.apply(pc, x, xs)
  325. y_opt.apply(pc, y, ys)
  326. res = process_images(pc)
  327. if grid_infotext[0] is None:
  328. pc.extra_generation_params = copy(pc.extra_generation_params)
  329. if x_opt.label != 'Nothing':
  330. pc.extra_generation_params["X Type"] = x_opt.label
  331. pc.extra_generation_params["X Values"] = x_values
  332. if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
  333. pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
  334. if y_opt.label != 'Nothing':
  335. pc.extra_generation_params["Y Type"] = y_opt.label
  336. pc.extra_generation_params["Y Values"] = y_values
  337. if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
  338. pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
  339. grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
  340. return res
  341. with SharedSettingsStackHelper():
  342. processed = draw_xy_grid(
  343. p,
  344. xs=xs,
  345. ys=ys,
  346. x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
  347. y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
  348. cell=cell,
  349. draw_legend=draw_legend,
  350. include_lone_images=include_lone_images,
  351. swap_axes_processing_order=swap_axes_processing_order
  352. )
  353. if opts.grid_save:
  354. images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
  355. return processed