processing.py

import contextlib
import json
import math
import os
import sys
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random

import modules.sd_hijack
from modules import devices
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
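# opt_C is the number of latent channels and opt_f the VAE downscaling factor:
# the samplers work on noise tensors of shape [opt_C, height // opt_f, width // opt_f]
# (see create_random_tensors and process_images below).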


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.prompt_style: str = prompt_style
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        self.paste_to = None

    def init(self, seed):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()
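
# Subclasses below (txt2img and img2img) override init() to construct their sampler and
# sample() to produce latents; process_images() drives either one through this interface.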


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
        self.images = images_list
        self.prompt = p.prompt
        self.seed = seed
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
        obj = {
            "prompt": self.prompt if type(self.prompt) != list else self.prompt[0],
            "seed": int(self.seed if type(self.seed) != list else self.seed[0]),
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm*high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
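
# slerp() performs spherical interpolation between two noise tensors: slerp(0.0, low, high)
# returns low, slerp(1.0, low, high) returns high, and intermediate values move along the
# arc between them. It is used below to blend the main seed's noise with the variation
# (sub)seed's noise.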


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
    xs = []
    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            torch.manual_seed(subseed)
            subnoise = torch.randn(noise_shape, device=shared.device)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        torch.manual_seed(seed)
        noise = torch.randn(noise_shape, device=shared.device)

        if subnoise is not None:
            #noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
            # noise_shape = (64, 80)
            # shape = (64, 72)
            torch.manual_seed(seed)
            x = torch.randn(shape, device=shared.device)
            dx = (shape[2] - noise_shape[2]) // 2  # -4
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        xs.append(noise)
    x = torch.stack(xs).to(shared.device)
    return x
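
# create_random_tensors returns a stacked batch of shape (len(seeds), *shape), one noise
# tensor per seed. When seed_resize_from_w/h are given, noise is first drawn at the old
# latent size and then pasted centered into (or cropped from) a tensor of the requested
# shape, so a resized image reuses as much of the original seed's noise pattern as possible.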


def fix_seed(p):
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
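
# fix_seed replaces a seed/subseed of -1 (or None) with a freshly drawn random value so
# that the seed actually used can be recorded in the generation parameters.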


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls p.init() once inside all the scopes and p.sample() once per batch"""

    assert p.prompt is not None
    devices.torch_gc()

    fix_seed(p)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = []

    modules.styles.apply_style(p, shared.prompt_styles[p.prompt_style])

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        all_seeds = [int(p.seed + x) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        index = position_in_batch + iteration * p.batch_size

        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[index],
            "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
            "Size": f"{p.width}x{p.height}",
            "Batch size": (None if p.batch_size < 2 else p.batch_size),
            "Batch pos": (None if p.batch_size < 2 else position_in_batch),
            "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
            "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
            "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        }

        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)

        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

        return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(seed=all_seeds[0])

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            c = p.sd_model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            if state.interrupted:
                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                output_images.append(image)

            state.nextjob()

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            return_grid = opts.return_grid

            grid = images.image_grid(output_images, p.batch_size)

            if return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext())
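
# Note: when a grid is built and opts.return_grid is enabled, it is inserted at index 0 of
# Processed.images, ahead of the individual samples.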


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def init(self, seed):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim
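
# A minimal txt2img usage sketch (illustrative only; output paths and parameter values are
# arbitrary, and shared.sd_model is assumed to hold the loaded checkpoint):
#
#   p = StableDiffusionProcessingTxt2Img(
#       sd_model=shared.sd_model,
#       outpath_samples="outputs/samples",
#       outpath_grids="outputs/grids",
#       prompt="a photo of a cat",
#       steps=20, cfg_scale=7.0, width=512, height=512, batch_size=1, n_iter=1,
#   )
#   processed = process_images(p)
#   processed.images[0].save("out.png")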


def get_crop_region(mask, pad=0):
    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )
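
# get_crop_region returns (x1, y1, x2, y2): the bounding box of the non-zero part of a 2D
# mask array, expanded by `pad` pixels and clamped to the mask bounds, using PIL's crop
# convention (right/bottom edges exclusive). For example, an 8x8 mask that is non-zero only
# at rows 2-3 and columns 4-5 yields (4, 2, 6, 4) with pad=0.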


def fill(image, mask):
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")
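
# fill() composites progressively less blurred copies of the unmasked content on top of
# each other, so the masked area ends up covered by heavily blurred colour deep inside the
# hole and sharper detail near its border; init() below uses it to pre-fill the region to
# be inpainted whenever self.inpainting_fill != 1.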


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, seed):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float64)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, latent_mask)

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask
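
    # inpainting_fill modes as implemented above: 0 pre-fills the masked pixels with blurred
    # surrounding content via fill(); 1 keeps the original pixels; 2 additionally replaces the
    # masked part of init_latent with fresh noise; 3 additionally zeroes the masked latent.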

    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples