xy_grid.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. from collections import namedtuple
  2. from copy import copy
  3. from itertools import permutations, chain
  4. import random
  5. import csv
  6. from io import StringIO
  7. from PIL import Image
  8. import numpy as np
  9. import modules.scripts as scripts
  10. import gradio as gr
  11. from modules import images, paths, sd_samplers, processing
  12. from modules.hypernetworks import hypernetwork
  13. from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
  14. from modules.shared import opts, cmd_opts, state
  15. import modules.shared as shared
  16. import modules.sd_samplers
  17. import modules.sd_models
  18. import modules.sd_vae
  19. import glob
  20. import os
  21. import re
  22. def apply_field(field):
  23. def fun(p, x, xs):
  24. setattr(p, field, x)
  25. return fun
  26. def apply_prompt(p, x, xs):
  27. if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
  28. raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
  29. p.prompt = p.prompt.replace(xs[0], x)
  30. p.negative_prompt = p.negative_prompt.replace(xs[0], x)
  31. def apply_order(p, x, xs):
  32. token_order = []
  33. # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
  34. for token in x:
  35. token_order.append((p.prompt.find(token), token))
  36. token_order.sort(key=lambda t: t[0])
  37. prompt_parts = []
  38. # Split the prompt up, taking out the tokens
  39. for _, token in token_order:
  40. n = p.prompt.find(token)
  41. prompt_parts.append(p.prompt[0:n])
  42. p.prompt = p.prompt[n + len(token):]
  43. # Rebuild the prompt with the tokens in the order we want
  44. prompt_tmp = ""
  45. for idx, part in enumerate(prompt_parts):
  46. prompt_tmp += part
  47. prompt_tmp += x[idx]
  48. p.prompt = prompt_tmp + p.prompt
  49. def apply_sampler(p, x, xs):
  50. sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
  51. if sampler_name is None:
  52. raise RuntimeError(f"Unknown sampler: {x}")
  53. p.sampler_name = sampler_name
  54. def confirm_samplers(p, xs):
  55. for x in xs:
  56. if x.lower() not in sd_samplers.samplers_map:
  57. raise RuntimeError(f"Unknown sampler: {x}")
  58. def apply_checkpoint(p, x, xs):
  59. info = modules.sd_models.get_closet_checkpoint_match(x)
  60. if info is None:
  61. raise RuntimeError(f"Unknown checkpoint: {x}")
  62. modules.sd_models.reload_model_weights(shared.sd_model, info)
  63. p.sd_model = shared.sd_model
  64. def confirm_checkpoints(p, xs):
  65. for x in xs:
  66. if modules.sd_models.get_closet_checkpoint_match(x) is None:
  67. raise RuntimeError(f"Unknown checkpoint: {x}")
  68. def apply_hypernetwork(p, x, xs):
  69. if x.lower() in ["", "none"]:
  70. name = None
  71. else:
  72. name = hypernetwork.find_closest_hypernetwork_name(x)
  73. if not name:
  74. raise RuntimeError(f"Unknown hypernetwork: {x}")
  75. hypernetwork.load_hypernetwork(name)
  76. def apply_hypernetwork_strength(p, x, xs):
  77. hypernetwork.apply_strength(x)
  78. def confirm_hypernetworks(p, xs):
  79. for x in xs:
  80. if x.lower() in ["", "none"]:
  81. continue
  82. if not hypernetwork.find_closest_hypernetwork_name(x):
  83. raise RuntimeError(f"Unknown hypernetwork: {x}")
  84. def apply_clip_skip(p, x, xs):
  85. opts.data["CLIP_stop_at_last_layers"] = x
  86. def apply_upscale_latent_space(p, x, xs):
  87. if x.lower().strip() != '0':
  88. opts.data["use_scale_latent_for_hires_fix"] = True
  89. else:
  90. opts.data["use_scale_latent_for_hires_fix"] = False
  91. def find_vae(name: str):
  92. if name.lower() in ['auto', 'none']:
  93. return name
  94. else:
  95. vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
  96. found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
  97. if found:
  98. return found[0]
  99. else:
  100. return 'auto'
  101. def apply_vae(p, x, xs):
  102. if x.lower().strip() == 'none':
  103. modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
  104. else:
  105. found = find_vae(x)
  106. if found:
  107. v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
  108. def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
  109. p.styles = x.split(',')
  110. def format_value_add_label(p, opt, x):
  111. if type(x) == float:
  112. x = round(x, 8)
  113. return f"{opt.label}: {x}"
  114. def format_value(p, opt, x):
  115. if type(x) == float:
  116. x = round(x, 8)
  117. return x
  118. def format_value_join_list(p, opt, x):
  119. return ", ".join(x)
  120. def do_nothing(p, x, xs):
  121. pass
  122. def format_nothing(p, opt, x):
  123. return ""
  124. def str_permutations(x):
  125. """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
  126. return x
  127. AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
  128. AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
  129. axis_options = [
  130. AxisOption("Nothing", str, do_nothing, format_nothing, None),
  131. AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
  132. AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
  133. AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
  134. AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
  135. AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
  136. AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
  137. AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
  138. AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
  139. AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
  140. AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
  141. AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
  142. AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
  143. AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
  144. AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
  145. AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
  146. AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
  147. AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
  148. AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
  149. AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
  150. AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
  151. AxisOption("VAE", str, apply_vae, format_value_add_label, None),
  152. AxisOption("Styles", str, apply_styles, format_value_add_label, None),
  153. ]
  154. def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
  155. ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
  156. hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
  157. # Temporary list of all the images that are generated to be populated into the grid.
  158. # Will be filled with empty images for any individual step that fails to process properly
  159. image_cache = []
  160. processed_result = None
  161. cell_mode = "P"
  162. cell_size = (1,1)
  163. state.job_count = len(xs) * len(ys) * p.n_iter
  164. for iy, y in enumerate(ys):
  165. for ix, x in enumerate(xs):
  166. state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
  167. processed:Processed = cell(x, y)
  168. try:
  169. # this dereference will throw an exception if the image was not processed
  170. # (this happens in cases such as if the user stops the process from the UI)
  171. processed_image = processed.images[0]
  172. if processed_result is None:
  173. # Use our first valid processed result as a template container to hold our full results
  174. processed_result = copy(processed)
  175. cell_mode = processed_image.mode
  176. cell_size = processed_image.size
  177. processed_result.images = [Image.new(cell_mode, cell_size)]
  178. image_cache.append(processed_image)
  179. if include_lone_images:
  180. processed_result.images.append(processed_image)
  181. processed_result.all_prompts.append(processed.prompt)
  182. processed_result.all_seeds.append(processed.seed)
  183. processed_result.infotexts.append(processed.infotexts[0])
  184. except:
  185. image_cache.append(Image.new(cell_mode, cell_size))
  186. if not processed_result:
  187. print("Unexpected error: draw_xy_grid failed to return even a single processed image")
  188. return Processed()
  189. grid = images.image_grid(image_cache, rows=len(ys))
  190. if draw_legend:
  191. grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
  192. processed_result.images[0] = grid
  193. return processed_result
  194. class SharedSettingsStackHelper(object):
  195. def __enter__(self):
  196. self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
  197. self.hypernetwork = opts.sd_hypernetwork
  198. self.model = shared.sd_model
  199. self.vae = opts.sd_vae
  200. def __exit__(self, exc_type, exc_value, tb):
  201. modules.sd_models.reload_model_weights(self.model)
  202. modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
  203. hypernetwork.load_hypernetwork(self.hypernetwork)
  204. hypernetwork.apply_strength()
  205. opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
  206. re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
  207. re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
  208. re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
  209. re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
  210. class Script(scripts.Script):
  211. def title(self):
  212. return "X/Y plot"
  213. def ui(self, is_img2img):
  214. current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
  215. with gr.Row():
  216. x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
  217. x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
  218. with gr.Row():
  219. y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
  220. y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
  221. draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
  222. include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
  223. no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
  224. return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
  225. def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
  226. if not no_fixed_seeds:
  227. modules.processing.fix_seed(p)
  228. if not opts.return_grid:
  229. p.batch_size = 1
  230. def process_axis(opt, vals):
  231. if opt.label == 'Nothing':
  232. return [0]
  233. valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
  234. if opt.type == int:
  235. valslist_ext = []
  236. for val in valslist:
  237. m = re_range.fullmatch(val)
  238. mc = re_range_count.fullmatch(val)
  239. if m is not None:
  240. start = int(m.group(1))
  241. end = int(m.group(2))+1
  242. step = int(m.group(3)) if m.group(3) is not None else 1
  243. valslist_ext += list(range(start, end, step))
  244. elif mc is not None:
  245. start = int(mc.group(1))
  246. end = int(mc.group(2))
  247. num = int(mc.group(3)) if mc.group(3) is not None else 1
  248. valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
  249. else:
  250. valslist_ext.append(val)
  251. valslist = valslist_ext
  252. elif opt.type == float:
  253. valslist_ext = []
  254. for val in valslist:
  255. m = re_range_float.fullmatch(val)
  256. mc = re_range_count_float.fullmatch(val)
  257. if m is not None:
  258. start = float(m.group(1))
  259. end = float(m.group(2))
  260. step = float(m.group(3)) if m.group(3) is not None else 1
  261. valslist_ext += np.arange(start, end + step, step).tolist()
  262. elif mc is not None:
  263. start = float(mc.group(1))
  264. end = float(mc.group(2))
  265. num = int(mc.group(3)) if mc.group(3) is not None else 1
  266. valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
  267. else:
  268. valslist_ext.append(val)
  269. valslist = valslist_ext
  270. elif opt.type == str_permutations:
  271. valslist = list(permutations(valslist))
  272. valslist = [opt.type(x) for x in valslist]
  273. # Confirm options are valid before starting
  274. if opt.confirm:
  275. opt.confirm(p, valslist)
  276. return valslist
  277. x_opt = axis_options[x_type]
  278. xs = process_axis(x_opt, x_values)
  279. y_opt = axis_options[y_type]
  280. ys = process_axis(y_opt, y_values)
  281. def fix_axis_seeds(axis_opt, axis_list):
  282. if axis_opt.label in ['Seed', 'Var. seed']:
  283. return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
  284. else:
  285. return axis_list
  286. if not no_fixed_seeds:
  287. xs = fix_axis_seeds(x_opt, xs)
  288. ys = fix_axis_seeds(y_opt, ys)
  289. if x_opt.label == 'Steps':
  290. total_steps = sum(xs) * len(ys)
  291. elif y_opt.label == 'Steps':
  292. total_steps = sum(ys) * len(xs)
  293. else:
  294. total_steps = p.steps * len(xs) * len(ys)
  295. if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
  296. total_steps *= 2
  297. print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
  298. shared.total_tqdm.updateTotal(total_steps * p.n_iter)
  299. grid_infotext = [None]
  300. def cell(x, y):
  301. pc = copy(p)
  302. x_opt.apply(pc, x, xs)
  303. y_opt.apply(pc, y, ys)
  304. res = process_images(pc)
  305. if grid_infotext[0] is None:
  306. pc.extra_generation_params = copy(pc.extra_generation_params)
  307. if x_opt.label != 'Nothing':
  308. pc.extra_generation_params["X Type"] = x_opt.label
  309. pc.extra_generation_params["X Values"] = x_values
  310. if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
  311. pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
  312. if y_opt.label != 'Nothing':
  313. pc.extra_generation_params["Y Type"] = y_opt.label
  314. pc.extra_generation_params["Y Values"] = y_values
  315. if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
  316. pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
  317. grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
  318. return res
  319. with SharedSettingsStackHelper():
  320. processed = draw_xy_grid(
  321. p,
  322. xs=xs,
  323. ys=ys,
  324. x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
  325. y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
  326. cell=cell,
  327. draw_legend=draw_legend,
  328. include_lone_images=include_lone_images
  329. )
  330. if opts.grid_save:
  331. images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
  332. return processed