processing.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. import contextlib
  2. import json
  3. import math
  4. import os
  5. import sys
  6. import torch
  7. import numpy as np
  8. from PIL import Image, ImageFilter, ImageOps
  9. import random
  10. import cv2
  11. from skimage import exposure
  12. import modules.sd_hijack
  13. from modules import devices, prompt_parser, masking
  14. from modules.sd_hijack import model_hijack
  15. from modules.sd_samplers import samplers, samplers_for_img2img
  16. from modules.shared import opts, cmd_opts, state
  17. import modules.shared as shared
  18. import modules.face_restoration
  19. import modules.images as images
  20. import modules.styles
  21. import logging
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4  # number of latent channels (first dim of the latent shape used in sample())
opt_f = 8  # image-to-latent downscale factor (latents are height//opt_f x width//opt_f)
  25. def setup_color_correction(image):
  26. logging.info("Calibrating color correction.")
  27. correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
  28. return correction_target
  29. def apply_color_correction(correction, image):
  30. logging.info("Applying color correction.")
  31. image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
  32. cv2.cvtColor(
  33. np.asarray(image),
  34. cv2.COLOR_RGB2LAB
  35. ),
  36. correction,
  37. channel_axis=2
  38. ), cv2.COLOR_LAB2RGB).astype("uint8"))
  39. return image
class StableDiffusionProcessing:
    """Container for all parameters of one generation run.

    process_images() drives the loop; subclasses (txt2img/img2img) implement
    init() and sample() for their respective pipelines.
    """

    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        # seed == -1 means "pick a random seed" (resolved later by fix_seed())
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        # set by img2img inpainting to (x, y, w, h) of the crop to paste results into
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = 0
        # sampler tuning parameters taken from global options at construction time
        self.ddim_discretize = opts.ddim_discretize
        self.s_churn = opts.s_churn
        self.s_tmin = opts.s_tmin
        self.s_tmax = float('inf')  # not representable as a standard ui option
        self.s_noise = opts.s_noise

        if not seed_enable_extras:
            # UI had "extra" seed controls disabled: reset them to neutral values
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        # per-run setup hook; overridden by subclasses
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        # produces the latent samples for one batch; must be overridden
        raise NotImplementedError()
class Processed:
    """Result of process_images(): the output images plus a snapshot of every
    parameter that produced them, serializable to JSON via js()."""

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        # denoising_strength only exists on img2img/highres runs
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        # 1 when a grid was prepended to self.images, 0 otherwise
        self.index_of_first_image = index_of_first_image
        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise

        # prompt/seed/subseed may arrive as per-image lists; normalize the
        # scalar fields to the first entry (full lists are kept in all_*)
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1

        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Return a JSON string of the generation parameters (for the UI)."""
        obj = {
            "prompt": self.prompt,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Regenerate the infotext for image *index* of self.images."""
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
  153. # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
  154. def slerp(val, low, high):
  155. low_norm = low/torch.norm(low, dim=1, keepdim=True)
  156. high_norm = high/torch.norm(high, dim=1, keepdim=True)
  157. dot = (low_norm*high_norm).sum(1)
  158. if dot.mean() > 0.9995:
  159. return low * val + high * (1 - val)
  160. omega = torch.acos(dot)
  161. so = torch.sin(omega)
  162. res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
  163. return res
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Build a stacked batch of seeded latent noise tensors, one per seed.

    shape is the per-image tensor shape (channels, height, width).  When
    subseeds are given, each tensor is slerped between the seed noise and the
    subseed noise by subseed_strength ("variation seed").  When
    seed_resize_from_h/w are positive, noise is generated at that resolution
    and pasted into the center of the target shape.
    """
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            # fall back to subseed 0 if fewer subseeds than seeds were supplied
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            # seed-resize path: paste the seeded noise into the center of a
            # fresh full-size noise tensor, cropping whichever is larger
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)
            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        # hand the pre-generated per-step noises to the sampler
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x
  208. def fix_seed(p):
  209. p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
  210. p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    """Build the human-readable parameters text saved alongside an image.

    Result format: prompt, optional "Negative prompt:" line, then a
    comma-separated "Key: value" list.  Entries whose value is None are
    omitted entirely.
    """
    index = position_in_batch + iteration * p.batch_size

    generation_params = {
        "Steps": p.steps,
        "Sampler": samplers[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
    }

    # extension-/subclass-provided extras override or extend the defaults
    generation_params.update(p.extra_generation_params)

    # a key equal to its value is emitted bare (flag-style), otherwise "k: v"
    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    # resolve -1/None/'' seeds into concrete random integers
    fix_seed(p)

    if p.outpath_samples is not None:
        os.makedirs(p.outpath_samples, exist_ok=True)

    if p.outpath_grids is not None:
        os.makedirs(p.outpath_grids, exist_ok=True)

    # enable/disable circular (tiling) convolution padding in the model
    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = {}

    shared.prompt_styles.apply_styles(p)

    # expand scalar prompt/seed/subseed into one entry per generated image;
    # lists given by the caller are used as-is
    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        # consecutive seeds per image — unless variation seeds are active,
        # in which case the main seed stays fixed and the subseed varies
        all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    infotexts = []
    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    # EMA weights swap is skipped in lowvram mode
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(all_prompts, all_seeds, all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            # slice out this batch's prompts/seeds
            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if (len(prompts) == 0):
                break

            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
            c = prompt_parser.get_learned_conditioning(prompts, p.steps)

            # collect any warnings the hijacked model produced (deduplicated via dict keys)
            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
            if state.interrupted:
                # if we are interruped, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

            # decode latents to images and map from [-1, 1] to [0, 1]
            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                # CHW float [0,1] -> HWC uint8 [0,255]
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        # inpaint-at-full-resolution: paste the generated crop
                        # back into its original position
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                infotexts.append(infotext(n, i))
                output_images.append(image)

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                # grid goes first; index_of_first_image tells callers to skip it
                infotexts.insert(0, infotext())
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """txt2img pipeline; with enable_hr, generates at a reduced "first phase"
    resolution near 512x512 total pixels, then upscales and runs img2img."""
    sampler = None
    firstphase_width = 0
    firstphase_height = 0
    firstphase_width_truncated = 0
    firstphase_height_truncated = 0

    def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        # scale_latent: upscale in latent space; otherwise decode, upscale as
        # an image, and re-encode
        self.scale_latent = scale_latent
        self.denoising_strength = denoising_strength

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            # highres runs two sampling passes, so double the job count
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            # pick a first-phase resolution with ~512x512 pixels, same aspect
            # ratio, rounded up to multiples of 64
            desired_pixel_count = 512 * 512
            actual_pixel_count = self.width * self.height
            scale = math.sqrt(desired_pixel_count / actual_pixel_count)

            self.firstphase_width = math.ceil(scale * self.width / 64) * 64
            self.firstphase_height = math.ceil(scale * self.height / 64) * 64
            # the exact (un-rounded) size; the excess is cropped off after sampling
            self.firstphase_width_truncated = int(scale * self.width)
            self.firstphase_height_truncated = int(scale * self.height)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

        if not self.enable_hr:
            # single-pass txt2img at the target resolution
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
            return samples

        # phase 1: sample at the reduced resolution
        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)

        # crop the rounding excess symmetrically, in latent units
        truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
        truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f

        samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]

        if self.scale_latent:
            # upscale directly in latent space
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
        else:
            # decode to image space, upscale there, then re-encode
            decoded_samples = self.sd_model.decode_first_stage(samples)

            if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
                decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
            else:
                # run the configured upscaler per image via PIL round-trip
                lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

                batch_images = []
                for i, x_sample in enumerate(lowres_samples):
                    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                    x_sample = x_sample.astype(np.uint8)
                    image = Image.fromarray(x_sample)
                    image = images.resize_image(0, image, self.width, self.height)
                    image = np.array(image).astype(np.float32) / 255.0
                    image = np.moveaxis(image, 2, 0)
                    batch_images.append(image)

                decoded_samples = torch.from_numpy(np.array(batch_images))
                decoded_samples = decoded_samples.to(shared.device)
                decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

        shared.state.nextjob()

        # phase 2: img2img over the upscaled latents with a fresh sampler
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)

        return samples
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """img2img pipeline: encodes init_images into latents and denoises from
    there; optionally restricted to a masked region (inpainting)."""
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        # inpainting_fill: 0=fill, 1=keep original, 2=latent noise, 3=latent zeros
        # (grounded by the != 1 / == 2 / == 3 branches in init())
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        # latent-space mask tensors computed in init(); mask keeps the
        # original, nmask marks the region to regenerate
        self.mask = None
        self.nmask = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                # inpaint only the masked region at full resolution: crop a
                # padded bounding box around the mask and work inside it
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                # remember where to paste the result back (used by process_images)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                # boost mask opacity for the overlay preview
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                # keep a masked copy of the original to composite over the result
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    # pre-fill the masked area (unless "keep original" mode)
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            # HWC uint8 -> CHW float [0,1]
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # one init image: replicate it across the batch
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            # fewer images than batch size: shrink the batch to match
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        # map [0,1] to the model's [-1,1] input range
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            # build latent-resolution binary masks from the (image-space) mask
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                # fill masked area with seeded latent noise
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                # fill masked area with zeros
                self.init_latent = self.init_latent * self.mask

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            # restore unmasked regions from the original latents
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples