prompt_matrix.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import math
  2. from collections import namedtuple
  3. from copy import copy
  4. import random
  5. import modules.scripts as scripts
  6. import gradio as gr
  7. from modules import images
  8. from modules.processing import process_images, Processed
  9. from modules.shared import opts, cmd_opts, state
  10. import modules.sd_samplers
  11. def draw_xy_grid(xs, ys, x_label, y_label, cell):
  12. res = []
  13. ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
  14. hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
  15. first_pocessed = None
  16. state.job_count = len(xs) * len(ys)
  17. for iy, y in enumerate(ys):
  18. for ix, x in enumerate(xs):
  19. state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
  20. processed = cell(x, y)
  21. if first_pocessed is None:
  22. first_pocessed = processed
  23. res.append(processed.images[0])
  24. grid = images.image_grid(res, rows=len(ys))
  25. grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
  26. first_pocessed.images = [grid]
  27. return first_pocessed
  28. class Script(scripts.Script):
  29. def title(self):
  30. return "Prompt matrix"
  31. def ui(self, is_img2img):
  32. put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False)
  33. return [put_at_start]
  34. def run(self, p, put_at_start):
  35. seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
  36. original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
  37. all_prompts = []
  38. prompt_matrix_parts = original_prompt.split("|")
  39. combination_count = 2 ** (len(prompt_matrix_parts) - 1)
  40. for combination_num in range(combination_count):
  41. selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
  42. if put_at_start:
  43. selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
  44. else:
  45. selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
  46. all_prompts.append(", ".join(selected_prompts))
  47. p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
  48. p.do_not_save_grid = True
  49. print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
  50. p.prompt = all_prompts
  51. p.prompt_for_display = original_prompt
  52. p.seed = len(all_prompts) * [seed]
  53. processed = process_images(p)
  54. grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
  55. grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
  56. processed.images.insert(0, grid)
  57. if opts.grid_save:
  58. images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=seed)
  59. return processed