processing.py

import contextlib
import json
import math
import os
import sys
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure

import modules.sd_hijack
from modules import devices, prompt_parser, masking
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles


# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8

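# Color correction helpers used by img2img: output tends to drift away from the source
# image's color balance, so we record the source's LAB representation and histogram-match
# generated images back to it.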
def setup_color_correction(image):
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target

def apply_color_correction(correction, image):
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    return image

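# Base class for a single generation job; it only stores parameters.
# The txt2img and img2img subclasses below implement init() and sample().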
class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: str = styles
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = 0

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        raise NotImplementedError()

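# Holds the results of a finished run: the output images plus everything needed to
# reproduce them (prompts, seeds, sampler settings); js() serializes the parameters to JSON.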
class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image

        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1

        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]

    def js(self):
        obj = {
            "prompt": self.prompt,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

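# Spherical linear interpolation between two noise tensors; used to blend the noise of a
# seed with that of a variation (sub)seed according to subseed_strength.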
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        return low * val + high * (1 - val)

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res

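# Builds the initial latent noise for every seed in the batch; optionally slerps toward
# subseed noise for variations, and can reproduce noise generated at a different resolution
# (seed_resize_from_*) so a seed's composition survives a change of image size.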
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    xs = []

    # if we have multiple seeds, this means we are working with batch size > 1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as two separate batches with seeds [100] and [101].
    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for the same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets the same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x

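# A seed (or subseed) of -1, None, or an empty string means "pick a random seed".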
def fix_seed(p):
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed

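# Builds the text block ("infotext") describing the generation parameters; it is embedded
# in saved image metadata and shown in the UI. Parameters whose value is None are omitted.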
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    generation_params = {
        "Steps": p.steps,
        "Sampler": samplers[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

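# Shared driver for txt2img and img2img: fixes seeds, expands prompts/seeds for the whole
# run, calls p.init() once and p.sample() once per batch, then decodes the latents,
# post-processes (face restoration, color correction, overlays) and saves the images.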
def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls p.init() once inside all the scopes and p.sample() once per batch"""

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    fix_seed(p)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = {}

    shared.prompt_styles.apply_styles(p)

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []

    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(all_prompts, all_seeds, all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if (len(prompts) == 0):
                break

            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
            c = prompt_parser.get_learned_conditioning(prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)

            if state.interrupted:
                # if we are interrupted, sample returns just noise;
                # use the image collected previously in the sampler loop
                samples_ddim = shared.state.current_latent

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    image = apply_color_correction(p.color_corrections[i], image)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                output_images.append(image)

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)

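# txt2img. When enable_hr ("highres fix") is set, the image is first sampled at a reduced
# first-phase resolution, upscaled (in latent or pixel space depending on scale_latent),
# and then refined with a second img2img-style pass at the target resolution.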
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None
    firstphase_width = 0
    firstphase_height = 0
    firstphase_width_truncated = 0
    firstphase_height_truncated = 0

    def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.scale_latent = scale_latent
        self.denoising_strength = denoising_strength

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            desired_pixel_count = 512 * 512
            actual_pixel_count = self.width * self.height
            scale = math.sqrt(desired_pixel_count / actual_pixel_count)

            self.firstphase_width = math.ceil(scale * self.width / 64) * 64
            self.firstphase_height = math.ceil(scale * self.height / 64) * 64
            self.firstphase_width_truncated = int(scale * self.width)
            self.firstphase_height_truncated = int(scale * self.height)

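    # Without hires fix this is a single sampling pass; with it, the first-phase result is
    # cropped to the truncated dimensions, upscaled, and re-sampled with sample_img2img.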
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)

        truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
        truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f

        samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]

        if self.scale_latent:
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
        else:
            decoded_samples = self.sd_model.decode_first_stage(samples)

            if opts.upscaler_for_hires_fix is None or opts.upscaler_for_hires_fix == "None":
                decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
            else:
                lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

                batch_images = []
                for i, x_sample in enumerate(lowres_samples):
                    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                    x_sample = x_sample.astype(np.uint8)
                    image = Image.fromarray(x_sample)
                    upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_hires_fix][0]
                    image = upscaler.upscale(image, self.width, self.height)
                    image = np.array(image).astype(np.float32) / 255.0
                    image = np.moveaxis(image, 2, 0)
                    batch_images.append(image)

                decoded_samples = torch.from_numpy(np.array(batch_images))
                decoded_samples = decoded_samples.to(shared.device)
                decoded_samples = 2. * decoded_samples - 1.

                samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

        shared.state.nextjob()

        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)

        return samples

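# img2img and inpainting. The init image is encoded into init_latent and denoising starts
# from it; image_mask (optionally blurred, inverted, or cropped to the masked region for
# "inpaint at full resolution") restricts which parts of the latent are changed.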
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

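    # Denoise from init_latent plus per-seed noise; after sampling, paste the original
    # latent back outside the mask so only the masked region is actually changed.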
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples