prompt_matrix.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import math
  2. from collections import namedtuple
  3. from copy import copy
  4. import random
  5. import modules.scripts as scripts
  6. import gradio as gr
  7. from modules import images
  8. from modules.processing import process_images, Processed
  9. from modules.shared import opts, cmd_opts, state
  10. import modules.sd_samplers
  11. import re
  12. def draw_xy_grid(xs, ys, x_label, y_label, cell):
  13. res = []
  14. ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
  15. hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
  16. first_processed = None
  17. state.job_count = len(xs) * len(ys)
  18. for iy, y in enumerate(ys):
  19. for ix, x in enumerate(xs):
  20. state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
  21. processed = cell(x, y)
  22. if first_processed is None:
  23. first_processed = processed
  24. res.append(processed.images[0])
  25. grid = images.image_grid(res, rows=len(ys))
  26. grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
  27. first_processed.images = [grid]
  28. return first_processed
  29. class Script(scripts.Script):
  30. def title(self):
  31. return "Prompt matrix"
  32. def elem_id(self, item_id):
  33. gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id
  34. gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id)
  35. return gen_elem_id
  36. def ui(self, is_img2img):
  37. put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
  38. different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
  39. return [put_at_start, different_seeds]
  40. def run(self, p, put_at_start, different_seeds):
  41. modules.processing.fix_seed(p)
  42. original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
  43. all_prompts = []
  44. prompt_matrix_parts = original_prompt.split("|")
  45. combination_count = 2 ** (len(prompt_matrix_parts) - 1)
  46. for combination_num in range(combination_count):
  47. selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
  48. if put_at_start:
  49. selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
  50. else:
  51. selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
  52. all_prompts.append(", ".join(selected_prompts))
  53. p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
  54. p.do_not_save_grid = True
  55. print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
  56. p.prompt = all_prompts
  57. p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
  58. p.prompt_for_display = original_prompt
  59. processed = process_images(p)
  60. grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
  61. grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
  62. processed.images.insert(0, grid)
  63. processed.index_of_first_image = 1
  64. processed.infotexts.insert(0, processed.infotexts[0])
  65. if opts.grid_save:
  66. images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)
  67. return processed