sd_upscale.py

import math

import modules.scripts as scripts
import gradio as gr
from PIL import Image

from modules import processing, shared, images, devices
from modules.processing import Processed
from modules.shared import opts, state

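# A webui script: title()/show()/ui()/run() define the "SD upscale" entry;
# show() returns is_img2img, so the script is only offered on the img2img tab.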
class Script(scripts.Script):
    def title(self):
        return "SD upscale"

    def show(self, is_img2img):
        return is_img2img

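    # Builds the gradio controls shown for the script; the img2img width and
    # height sliders (p.width/p.height) double as the tile size.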
    def ui(self, is_img2img):
        info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
        overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
        scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
        upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index"))

        return [info, overlap, upscaler_index, scale_factor]

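    # `_` receives the value of the info HTML component, which is unused; the
    # remaining arguments arrive in the order they were returned from ui().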
    def run(self, p, _, overlap, upscaler_index, scale_factor):
        if isinstance(upscaler_index, str):
            upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
        processing.fix_seed(p)
        upscaler = shared.sd_upscalers[upscaler_index]

        p.extra_generation_params["SD upscale overlap"] = overlap
        p.extra_generation_params["SD upscale upscaler"] = upscaler.name

        initial_info = None
        seed = p.seed

        init_img = p.init_images[0]
        init_img = images.flatten(init_img, opts.img2img_background_color)

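        # First enlarge the whole image with the selected conventional
        # upscaler; choosing "None" keeps the original size.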
        if upscaler.name != "None":
            img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
        else:
            img = init_img

        devices.torch_gc()

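        # Split the enlarged image into tiles of the img2img width/height,
        # overlapping by `overlap` pixels so seams blend when recombined.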
        grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)

        batch_size = p.batch_size
        upscale_count = p.n_iter
        p.n_iter = 1
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        work = []

        for _y, _h, row in grid.tiles:
            for tiledata in row:
                work.append(tiledata[2])

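        # Tiles are run through img2img in batches of the original batch size;
        # the job count covers every batch across all requested upscale passes.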
        batch_count = math.ceil(len(work) / batch_size)
        state.job_count = batch_count * upscale_count

        print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")

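        # Each outer pass regenerates every tile starting from seed + n and
        # yields one complete upscaled image.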
        result_images = []
        for n in range(upscale_count):
            start_seed = seed + n
            p.seed = start_seed

            work_results = []
            for i in range(batch_count):
                p.batch_size = batch_size
                p.init_images = work[i * batch_size:(i + 1) * batch_size]

                state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
                processed = processing.process_images(p)

                if initial_info is None:
                    initial_info = processed.info

                p.seed = processed.seed + 1
                work_results += processed.images

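            # Write the regenerated tiles back into the grid (padding with
            # blank tiles if any are missing) and blend the overlaps into a
            # single combined image.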
            image_index = 0
            for _y, _h, row in grid.tiles:
                for tiledata in row:
                    tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
                    image_index += 1

            combined_image = images.combine_grid(grid)
            result_images.append(combined_image)

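            # Per-sample saving was disabled above, so each combined result is
            # saved explicitly here.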
            if opts.samples_save:
                images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)

        processed = Processed(p, result_images, seed, initial_info)

        return processed