prompts_from_file.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import copy
  2. import math
  3. import os
  4. import random
  5. import sys
  6. import traceback
  7. import shlex
  8. import modules.scripts as scripts
  9. import gradio as gr
  10. from modules.processing import Processed, process_images
  11. from PIL import Image
  12. from modules.shared import opts, cmd_opts, state
  13. def process_string_tag(tag):
  14. return tag
  15. def process_int_tag(tag):
  16. return int(tag)
  17. def process_float_tag(tag):
  18. return float(tag)
  19. def process_boolean_tag(tag):
  20. return True if (tag == "true") else False
  21. prompt_tags = {
  22. "sd_model": None,
  23. "outpath_samples": process_string_tag,
  24. "outpath_grids": process_string_tag,
  25. "prompt_for_display": process_string_tag,
  26. "prompt": process_string_tag,
  27. "negative_prompt": process_string_tag,
  28. "styles": process_string_tag,
  29. "seed": process_int_tag,
  30. "subseed_strength": process_float_tag,
  31. "subseed": process_int_tag,
  32. "seed_resize_from_h": process_int_tag,
  33. "seed_resize_from_w": process_int_tag,
  34. "sampler_index": process_int_tag,
  35. "batch_size": process_int_tag,
  36. "n_iter": process_int_tag,
  37. "steps": process_int_tag,
  38. "cfg_scale": process_float_tag,
  39. "width": process_int_tag,
  40. "height": process_int_tag,
  41. "restore_faces": process_boolean_tag,
  42. "tiling": process_boolean_tag,
  43. "do_not_save_samples": process_boolean_tag,
  44. "do_not_save_grid": process_boolean_tag
  45. }
  46. def cmdargs(line):
  47. args = shlex.split(line)
  48. pos = 0
  49. res = {}
  50. while pos < len(args):
  51. arg = args[pos]
  52. assert arg.startswith("--"), f'must start with "--": {arg}'
  53. tag = arg[2:]
  54. func = prompt_tags.get(tag, None)
  55. assert func, f'unknown commandline option: {arg}'
  56. assert pos+1 < len(args), f'missing argument for command line option {arg}'
  57. val = args[pos+1]
  58. res[tag] = func(val)
  59. pos += 2
  60. return res
  61. def load_prompt_file(file):
  62. if (file is None):
  63. lines = []
  64. else:
  65. lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
  66. return None, "\n".join(lines), gr.update(lines=7)
  67. class Script(scripts.Script):
  68. def title(self):
  69. return "Prompts from file or textbox"
  70. def ui(self, is_img2img):
  71. checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
  72. checkbox_iterate_batch = gr.Checkbox(label="Preserve random seed across lines (for use with \"Generate Forever\")", value=False)
  73. prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
  74. file = gr.File(label="Upload prompt inputs", type='bytes')
  75. file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
  76. # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
  77. # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
  78. # be unclear to the user that shift-enter is needed.
  79. prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
  80. return [checkbox_iterate, checkbox_iterate_batch, file, prompt_txt]
  81. def run(self, p, checkbox_iterate, checkbox_iterate_batch, file, prompt_txt: str):
  82. lines = [x.strip() for x in prompt_txt.splitlines()]
  83. lines = [x for x in lines if len(x) > 0]
  84. p.do_not_save_grid = True
  85. job_count = 0
  86. jobs = []
  87. for line in lines:
  88. if "--" in line:
  89. try:
  90. args = cmdargs(line)
  91. except Exception:
  92. print(f"Error parsing line [line] as commandline:", file=sys.stderr)
  93. print(traceback.format_exc(), file=sys.stderr)
  94. args = {"prompt": line}
  95. else:
  96. args = {"prompt": line}
  97. n_iter = args.get("n_iter", 1)
  98. if n_iter != 1:
  99. job_count += n_iter
  100. else:
  101. job_count += 1
  102. jobs.append(args)
  103. print(f"Will process {len(lines)} lines in {job_count} jobs.")
  104. if ((checkbox_iterate or checkbox_iterate_batch) and p.seed == -1):
  105. p.seed = int(random.randrange(4294967294))
  106. state.job_count = job_count
  107. images = []
  108. for n, args in enumerate(jobs):
  109. state.job = f"{state.job_no + 1} out of {state.job_count}"
  110. copy_p = copy.copy(p)
  111. for k, v in args.items():
  112. setattr(copy_p, k, v)
  113. proc = process_images(copy_p)
  114. images += proc.images
  115. if (checkbox_iterate):
  116. p.seed = p.seed + (p.batch_size * p.n_iter)
  117. return Processed(p, images, p.seed, "")