# processing.py

import contextlib
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random

from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.gfpgan_model as gfpgan
import modules.images as images

# some of these options should not be changed at all because they would break the model, so they are removed from options.
opt_C = 4
opt_f = 8


def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.negative_prompt: str = (negative_prompt or "")
        self.seed: int = seed
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.prompt_matrix: bool = prompt_matrix
        self.use_GFPGAN: bool = use_GFPGAN
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        self.paste_to = None

    def init(self):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
        self.images = images_list
        self.prompt = p.prompt
        self.seed = seed
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
        obj = {
            "prompt": self.prompt,
            "seed": int(self.seed),
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)
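

# Example (illustrative values, not from a real run): js() serializes the recorded
# generation parameters to a JSON string such as
#     '{"prompt": "a cat", "seed": 1234, "width": 512, "height": 512, "sampler": "Euler a", "cfg_scale": 7.0, "steps": 50}'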


def create_random_tensors(shape, seeds):
    xs = []
    for seed in seeds:
        torch.manual_seed(seed)

        # randn results depend on device; gpu and cpu get different results for the same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets the same result;
        # but the original script had it like this so I do not dare change it for now because
        # it will break everyone's seeds.
        xs.append(torch.randn(shape, device=shared.device))
    x = torch.stack(xs)
    return x
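

# A minimal sketch of the property this relies on (assuming a fixed device, since
# torch.randn gives different results for the same seed on CPU vs GPU):
#
#     torch.manual_seed(42)
#     a = torch.randn([4, 64, 64])
#     torch.manual_seed(42)
#     b = torch.randn([4, 64, 64])
#     torch.equal(a, b)  # True: same seed + same device -> identical noise
#
# Reseeding before every torch.randn call is what makes each image reproducible
# from its own seed regardless of its position in the batch.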


def process_images(p: StableDiffusionProcessing) -> Processed:
    """This is the main loop that both txt2img and img2img use; it calls p.init() once inside all the scopes and p.sample() once per batch."""

    prompt = p.prompt
    assert p.prompt is not None
    torch_gc()

    seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    comments = []

    prompt_matrix_parts = []
    if p.prompt_matrix:
        all_prompts = []
        prompt_matrix_parts = prompt.split("|")
        combination_count = 2 ** (len(prompt_matrix_parts) - 1)
        for combination_num in range(combination_count):
            selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]

            if opts.prompt_matrix_add_to_start:
                selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
            else:
                selected_prompts = [prompt_matrix_parts[0]] + selected_prompts

            all_prompts.append(", ".join(selected_prompts))

        p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
        all_seeds = len(all_prompts) * [seed]

        print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
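
        # Worked example (illustrative prompt): "a castle|painting|at night" splits into
        # ["a castle", "painting", "at night"]; the first part is always kept, and the bits
        # of combination_num toggle the remaining parts, giving 2**2 = 4 prompts:
        #     "a castle"
        #     "a castle, painting"
        #     "a castle, at night"
        #     "a castle, painting, at night"
        # (with opts.prompt_matrix_add_to_start set, the toggled parts come first and the
        # base part goes last instead)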
    else:
        all_prompts = p.batch_size * p.n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[position_in_batch + iteration * p.batch_size],
            "GFPGAN": ("GFPGAN" if p.use_GFPGAN else None)
        }

        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)

        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

        return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
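
    # Example (illustrative values): for the prompt "a cat" this produces text like
    #     a cat
    #     Steps: 50, Sampler: Euler a, CFG scale: 7.0, Seed: 1234, GFPGAN
    # Note the `k if k == v` trick: entries whose key equals their value
    # (like "GFPGAN": "GFPGAN") are emitted as a bare flag rather than "k: v",
    # and None values are dropped entirely.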
    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init()

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]

            uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            c = p.sd_model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.use_GFPGAN:
                    torch_gc()
                    x_sample = gfpgan.gfpgan_fix_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i))
                output_images.append(image)

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            return_grid = opts.return_grid

            if p.prompt_matrix:
                grid = images.image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))

                try:
                    grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
                except Exception:
                    import traceback
                    print("Error creating prompt_matrix text:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                return_grid = True
            else:
                grid = images.image_grid(output_images, p.batch_size)

            if return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)

    torch_gc()
    return Processed(p, output_images, seed, infotext())


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def init(self):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim


def get_crop_region(mask, pad=0):
    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left - pad, 0)),
        int(max(crop_top - pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )
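

# Example (illustrative): get_crop_region scans inward from each edge and returns
# the bounding box (left, top, right, bottom) of the non-zero mask pixels, padded
# by `pad` and clamped to the image bounds:
#
#     mask = np.zeros((8, 8), dtype=np.uint8)
#     mask[2:5, 3:6] = 255
#     get_crop_region(mask)         # -> (3, 2, 6, 5)
#     get_crop_region(mask, pad=1)  # -> (2, 1, 7, 6)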


def fill(image, mask):
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")
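

# How fill() works, briefly: the masked area is left transparent (the inverted mask
# means white inpaint pixels are never pasted), then progressively finer Gaussian
# blurs of the premultiplied-alpha image ('RGBa') are composited on top of one
# another. The coarse radius-64 pass smears surrounding colors far into the hole;
# each finer pass refines the region near the hole's border, so the masked area
# ends up filled with a smooth extrapolation of its surroundings instead of the
# original content.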


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
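        # inpainting_fill selects how masked content is initialized before sampling
        # (mode names assumed from the UI, not defined in this file): 0 = fill with
        # blurred surroundings via fill(), 1 = keep original content, 2 = replace the
        # masked latent with noise, 3 = zero the masked latent; see init() below.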
        self.inpaint_full_res = inpaint_full_res
        self.mask = None
        self.nmask = None

    def init(self):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), 64)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2 - x1, y2 - y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                self.mask_for_overlay = self.image_mask

            self.overlay_images = []

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, self.mask_for_overlay)

                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
        if self.image_mask is not None:
            latmask = self.image_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [self.seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask
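
        # self.mask is 1 where the original latent should be preserved and self.nmask
        # is 1 where new content is generated; sample() below uses the pair to blend
        # the sampler output back over the preserved region after denoising.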
    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples