# poor_mans_outpainting.py (5.6 KB)
  1. import math
  2. import modules.scripts as scripts
  3. import gradio as gr
  4. from PIL import Image, ImageDraw
  5. from modules import images, devices
  6. from modules.processing import Processed, process_images
  7. from modules.shared import opts, state
  class Script(scripts.Script):
      """img2img script that outpaints by enlarging the canvas and inpainting only the new border.

      The input image is pasted onto a larger canvas whose size is rounded up to a multiple of 64.
      The canvas is split into overlapping tiles of ``p.width`` x ``p.height``. Every tile that is
      not entirely inside the pasted original is run through ``process_images`` with a mask that
      marks the added border. The processed tiles are then stitched back together with
      ``images.combine_grid``.
      """

      def title(self):
          """Return the display name of this script."""
          return "Poor man's outpainting"

      def show(self, is_img2img):
          """Make the script available only when *is_img2img* is true."""
          return is_img2img

      def ui(self, is_img2img):
          """Build the Gradio controls; returns ``None`` outside img2img.

          The returned components are passed to :meth:`run` in this order:
          pixels, mask_blur, inpainting_fill, direction.
          """
          if not is_img2img:
              return None

          pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
          mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
          # type="index": run() receives the position of the chosen option, not its label.
          inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
          direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))

          return [pixels, mask_blur, inpainting_fill, direction]

      def run(self, p, pixels, mask_blur, inpainting_fill, direction):
          """Expand ``p.init_images[0]`` on every side in *direction* and inpaint the new area.

          Args:
              p: img2img processing object. Its init images, masks, seed, batch settings and
                  saving flags are overwritten here.
              pixels: requested expansion per selected side; also used as the tile overlap.
              mask_blur: UI blur value; ``p.mask_blur`` is set to twice this value.
              inpainting_fill: index into the "Masked content" choices.
              direction: the selected subset of ``['left', 'right', 'up', 'down']``.

          Returns:
              A ``Processed`` holding the single stitched image and the seed/info of the first
              tile that was processed (``None`` if no tile needed processing).
          """
          initial_seed = None
          initial_info = None

          # NOTE(review): the UI blur value is doubled; the main mask below is pulled inward by the
          # same doubled amount on expanded sides, presumably so the blur stays in the new border.
          p.mask_blur = mask_blur * 2
          p.inpainting_fill = inpainting_fill
          p.inpaint_full_res = False

          left = pixels if "left" in direction else 0
          right = pixels if "right" in direction else 0
          up = pixels if "up" in direction else 0
          down = pixels if "down" in direction else 0

          init_img = p.init_images[0]

          # Round the enlarged canvas up to a multiple of 64, then spread the total extra width
          # (and height) proportionally over the selected sides.
          target_w = math.ceil((init_img.width + left + right) / 64) * 64
          target_h = math.ceil((init_img.height + up + down) / 64) * 64

          if left > 0:
              left = left * (target_w - init_img.width) // (left + right)
          if right > 0:
              right = target_w - init_img.width - left
          if up > 0:
              up = up * (target_h - init_img.height) // (up + down)
          if down > 0:
              down = target_h - init_img.height - up

          # New canvas (black by default) with the original pasted at its offset.
          img = Image.new("RGB", (target_w, target_h))
          img.paste(init_img, (left, up))

          def keep_mask(inset):
              """Return a white "L" mask of the canvas size with the original area painted black.

              On every expanded side, the black rectangle is pulled inward by *inset* pixels.
              """
              canvas = Image.new("L", (img.width, img.height), "white")
              ImageDraw.Draw(canvas).rectangle((
                  left + (inset if left > 0 else 0),
                  up + (inset if up > 0 else 0),
                  # ImageDraw rectangles include their end coordinates, so stop one pixel short.
                  # Otherwise the first column/row of the right/bottom padding would be kept.
                  canvas.width - right - (inset if right > 0 else 0) - 1,
                  canvas.height - down - (inset if down > 0 else 0) - 1,
              ), fill="black")
              return canvas

          mask = keep_mask(mask_blur * 2)
          latent_mask = keep_mask(mask_blur // 2)

          # Presumably frees cached GPU memory before the tile loop -- confirm in modules.devices.
          devices.torch_gc()

          grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
          grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
          grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels)

          # Each tile is processed as a single image; only the stitched result is saved.
          p.n_iter = 1
          p.batch_size = 1
          p.do_not_save_grid = True
          p.do_not_save_samples = True

          def inside_original(x, y, w, h):
              """Return True if the tile lies wholly within the pasted original (nothing to outpaint)."""
              return x >= left and x + w <= img.width - right and y >= up and y + h <= img.height - down

          # grid.tiles is a list of (y, h, row); each row item is indexable as [x, w, tile_image].
          work = []
          work_mask = []
          work_latent_mask = []
          work_results = []
          for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
              for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
                  x, w = tiledata[0:2]
                  if inside_original(x, y, w, h):
                      continue
                  work.append(tiledata[2])
                  work_mask.append(tiledata_mask[2])
                  work_latent_mask.append(tiledata_latent_mask[2])

          batch_count = len(work)
          print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.")

          state.job_count = batch_count

          for i, (tile, tile_mask, tile_latent_mask) in enumerate(zip(work, work_mask, work_latent_mask)):
              p.init_images = [tile]
              p.image_mask = tile_mask
              p.latent_mask = tile_latent_mask

              state.job = f"Batch {i + 1} out of {batch_count}"
              processed = process_images(p)

              if initial_seed is None:
                  initial_seed = processed.seed
                  initial_info = processed.info

              # Each subsequent tile uses the next seed.
              p.seed = processed.seed + 1
              work_results += processed.images

          # Write results back in the same traversal order used to collect the work.
          image_index = 0
          for y, h, row in grid.tiles:
              for tiledata in row:
                  x, w = tiledata[0:2]
                  if inside_original(x, y, w, h):
                      continue
                  # If fewer results than tiles came back, fill the gap with a blank tile.
                  tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
                  image_index += 1

          combined_image = images.combine_grid(grid)

          if opts.samples_save:
              images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p)

          processed = Processed(p, [combined_image], initial_seed, initial_info)

          return processed