outpainting_mk_2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import math
  2. import numpy as np
  3. import skimage
  4. import modules.scripts as scripts
  5. import gradio as gr
  6. from PIL import Image, ImageDraw
  7. from modules import images
  8. from modules.processing import Processed, process_images
  9. from modules.shared import opts, state
  10. # this function is taken from https://github.com/parlance-zz/g-diffuser-bot
  11. def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
  12. # helper fft routines that keep ortho normalization and auto-shift before and after fft
  13. def _fft2(data):
  14. if data.ndim > 2: # has channels
  15. out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
  16. for c in range(data.shape[2]):
  17. c_data = data[:, :, c]
  18. out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
  19. out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
  20. else: # one channel
  21. out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
  22. out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
  23. out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
  24. return out_fft
  25. def _ifft2(data):
  26. if data.ndim > 2: # has channels
  27. out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
  28. for c in range(data.shape[2]):
  29. c_data = data[:, :, c]
  30. out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
  31. out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
  32. else: # one channel
  33. out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
  34. out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
  35. out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
  36. return out_ifft
  37. def _get_gaussian_window(width, height, std=3.14, mode=0):
  38. window_scale_x = float(width / min(width, height))
  39. window_scale_y = float(height / min(width, height))
  40. window = np.zeros((width, height))
  41. x = (np.arange(width) / width * 2. - 1.) * window_scale_x
  42. for y in range(height):
  43. fy = (y / height * 2. - 1.) * window_scale_y
  44. if mode == 0:
  45. window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
  46. else:
  47. window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian
  48. return window
  49. def _get_masked_window_rgb(np_mask_grey, hardness=1.):
  50. np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
  51. if hardness != 1.:
  52. hardened = np_mask_grey[:] ** hardness
  53. else:
  54. hardened = np_mask_grey[:]
  55. for c in range(3):
  56. np_mask_rgb[:, :, c] = hardened[:]
  57. return np_mask_rgb
  58. width = _np_src_image.shape[0]
  59. height = _np_src_image.shape[1]
  60. num_channels = _np_src_image.shape[2]
  61. _np_src_image[:] * (1. - np_mask_rgb)
  62. np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
  63. img_mask = np_mask_grey > 1e-6
  64. ref_mask = np_mask_grey < 1e-3
  65. windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
  66. windowed_image /= np.max(windowed_image)
  67. windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
  68. src_fft = _fft2(windowed_image) # get feature statistics from masked src img
  69. src_dist = np.absolute(src_fft)
  70. src_phase = src_fft / src_dist
  71. # create a generator with a static seed to make outpainting deterministic / only follow global seed
  72. rng = np.random.default_rng(0)
  73. noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
  74. noise_rgb = rng.random((width, height, num_channels))
  75. noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
  76. noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
  77. for c in range(num_channels):
  78. noise_rgb[:, :, c] += (1. - color_variation) * noise_grey
  79. noise_fft = _fft2(noise_rgb)
  80. for c in range(num_channels):
  81. noise_fft[:, :, c] *= noise_window
  82. noise_rgb = np.real(_ifft2(noise_fft))
  83. shaped_noise_fft = _fft2(noise_rgb)
  84. shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
  85. brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now
  86. contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
  87. # scikit-image is used for histogram matching, very convenient!
  88. shaped_noise = np.real(_ifft2(shaped_noise_fft))
  89. shaped_noise -= np.min(shaped_noise)
  90. shaped_noise /= np.max(shaped_noise)
  91. shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
  92. shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
  93. matched_noise = shaped_noise[:]
  94. return np.clip(matched_noise, 0., 1.)
  95. class Script(scripts.Script):
  96. def title(self):
  97. return "Outpainting mk2"
  98. def show(self, is_img2img):
  99. return is_img2img
  100. def ui(self, is_img2img):
  101. if not is_img2img:
  102. return None
  103. info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")
  104. pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
  105. mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
  106. direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
  107. noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
  108. color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))
  109. return [info, pixels, mask_blur, direction, noise_q, color_variation]
  110. def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
  111. initial_seed_and_info = [None, None]
  112. process_width = p.width
  113. process_height = p.height
  114. p.inpaint_full_res = False
  115. p.inpainting_fill = 1
  116. p.do_not_save_samples = True
  117. p.do_not_save_grid = True
  118. left = pixels if "left" in direction else 0
  119. right = pixels if "right" in direction else 0
  120. up = pixels if "up" in direction else 0
  121. down = pixels if "down" in direction else 0
  122. if left > 0 or right > 0:
  123. mask_blur_x = mask_blur
  124. else:
  125. mask_blur_x = 0
  126. if up > 0 or down > 0:
  127. mask_blur_y = mask_blur
  128. else:
  129. mask_blur_y = 0
  130. p.mask_blur_x = mask_blur_x*4
  131. p.mask_blur_y = mask_blur_y*4
  132. init_img = p.init_images[0]
  133. target_w = math.ceil((init_img.width + left + right) / 64) * 64
  134. target_h = math.ceil((init_img.height + up + down) / 64) * 64
  135. if left > 0:
  136. left = left * (target_w - init_img.width) // (left + right)
  137. if right > 0:
  138. right = target_w - init_img.width - left
  139. if up > 0:
  140. up = up * (target_h - init_img.height) // (up + down)
  141. if down > 0:
  142. down = target_h - init_img.height - up
  143. def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
  144. is_horiz = is_left or is_right
  145. is_vert = is_top or is_bottom
  146. pixels_horiz = expand_pixels if is_horiz else 0
  147. pixels_vert = expand_pixels if is_vert else 0
  148. images_to_process = []
  149. output_images = []
  150. for n in range(count):
  151. res_w = init[n].width + pixels_horiz
  152. res_h = init[n].height + pixels_vert
  153. process_res_w = math.ceil(res_w / 64) * 64
  154. process_res_h = math.ceil(res_h / 64) * 64
  155. img = Image.new("RGB", (process_res_w, process_res_h))
  156. img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
  157. mask = Image.new("RGB", (process_res_w, process_res_h), "white")
  158. draw = ImageDraw.Draw(mask)
  159. draw.rectangle((
  160. expand_pixels + mask_blur_x if is_left else 0,
  161. expand_pixels + mask_blur_y if is_top else 0,
  162. mask.width - expand_pixels - mask_blur_x if is_right else res_w,
  163. mask.height - expand_pixels - mask_blur_y if is_bottom else res_h,
  164. ), fill="black")
  165. np_image = (np.asarray(img) / 255.0).astype(np.float64)
  166. np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
  167. noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
  168. output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
  169. target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
  170. target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
  171. p.width = target_width if is_horiz else img.width
  172. p.height = target_height if is_vert else img.height
  173. crop_region = (
  174. 0 if is_left else output_images[n].width - target_width,
  175. 0 if is_top else output_images[n].height - target_height,
  176. target_width if is_left else output_images[n].width,
  177. target_height if is_top else output_images[n].height,
  178. )
  179. mask = mask.crop(crop_region)
  180. p.image_mask = mask
  181. image_to_process = output_images[n].crop(crop_region)
  182. images_to_process.append(image_to_process)
  183. p.init_images = images_to_process
  184. latent_mask = Image.new("RGB", (p.width, p.height), "white")
  185. draw = ImageDraw.Draw(latent_mask)
  186. draw.rectangle((
  187. expand_pixels + mask_blur_x * 2 if is_left else 0,
  188. expand_pixels + mask_blur_y * 2 if is_top else 0,
  189. mask.width - expand_pixels - mask_blur_x * 2 if is_right else res_w,
  190. mask.height - expand_pixels - mask_blur_y * 2 if is_bottom else res_h,
  191. ), fill="black")
  192. p.latent_mask = latent_mask
  193. proc = process_images(p)
  194. if initial_seed_and_info[0] is None:
  195. initial_seed_and_info[0] = proc.seed
  196. initial_seed_and_info[1] = proc.info
  197. for n in range(count):
  198. output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
  199. output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
  200. return output_images
  201. batch_count = p.n_iter
  202. batch_size = p.batch_size
  203. p.n_iter = 1
  204. state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
  205. all_processed_images = []
  206. for i in range(batch_count):
  207. imgs = [init_img] * batch_size
  208. state.job = f"Batch {i + 1} out of {batch_count}"
  209. if left > 0:
  210. imgs = expand(imgs, batch_size, left, is_left=True)
  211. if right > 0:
  212. imgs = expand(imgs, batch_size, right, is_right=True)
  213. if up > 0:
  214. imgs = expand(imgs, batch_size, up, is_top=True)
  215. if down > 0:
  216. imgs = expand(imgs, batch_size, down, is_bottom=True)
  217. all_processed_images += imgs
  218. all_images = all_processed_images
  219. combined_grid_image = images.image_grid(all_processed_images)
  220. unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
  221. if opts.return_grid and not unwanted_grid_because_of_img_count:
  222. all_images = [combined_grid_image] + all_processed_images
  223. res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
  224. if opts.samples_save:
  225. for img in all_processed_images:
  226. images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p)
  227. if opts.grid_save and not unwanted_grid_because_of_img_count:
  228. images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
  229. return res