batch.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import math
  2. import os
  3. import sys
  4. import traceback
  5. import modules.scripts as scripts
  6. import gradio as gr
  7. from modules.processing import Processed, process_images
  8. from PIL import Image
  9. from modules.shared import opts, cmd_opts, state
  10. class Script(scripts.Script):
  11. def title(self):
  12. return "Batch processing"
  13. def show(self, is_img2img):
  14. return is_img2img
  15. def ui(self, is_img2img):
  16. input_dir = gr.Textbox(label="Input directory", lines=1)
  17. output_dir = gr.Textbox(label="Output directory", lines=1)
  18. return [input_dir, output_dir]
  19. def run(self, p, input_dir, output_dir):
  20. images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
  21. batch_count = math.ceil(len(images) / p.batch_size)
  22. print(f"Will process {len(images)} images in {batch_count} batches.")
  23. p.batch_count = 1
  24. p.do_not_save_grid = True
  25. p.do_not_save_samples = True
  26. state.job_count = batch_count
  27. for batch_no in range(batch_count):
  28. batch_images = []
  29. for path in images[batch_no*p.batch_size:(batch_no+1)*p.batch_size]:
  30. try:
  31. img = Image.open(path)
  32. batch_images.append((img, path))
  33. except:
  34. print(f"Error processing {path}:", file=sys.stderr)
  35. print(traceback.format_exc(), file=sys.stderr)
  36. if len(batch_images) == 0:
  37. continue
  38. state.job = f"{batch_no} out of {batch_count}: {batch_images[0][1]}"
  39. p.init_images = [x[0] for x in batch_images]
  40. proc = process_images(p)
  41. for image, (_, path) in zip(proc.images, batch_images):
  42. filename = os.path.basename(path)
  43. image.save(os.path.join(output_dir, filename))
  44. return Processed(p, [], p.seed, "")