# processing.py

import json
import math
import os
import sys
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
import modules.sd_hijack
from modules import devices, prompt_parser, masking
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
import logging

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8


def setup_color_correction(image):
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, image):
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    return image
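
# Illustrative pairing of the two helpers above (a sketch, not a call site from
# this file; `reference` and `generated` stand for PIL RGB images):
#   correction = setup_color_correction(reference)
#   matched = apply_color_correction(correction, generated)
# img2img uses this pair to keep output colors close to the init image's
# histogram when the img2img_color_correction option is enabled.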


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = 0
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = opts.ddim_discretize
        self.s_churn = opts.s_churn
        self.s_tmin = opts.s_tmin
        self.s_tmax = float('inf')  # not representable as a standard ui option
        self.s_noise = opts.s_noise

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        raise NotImplementedError()


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override

        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1

        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        obj = {
            "prompt": self.prompt,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
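
# Sketch of consuming a Processed result (a hypothetical caller; in the web UI
# the .js() JSON is what gets surfaced as the generation-info payload):
#   processed = process_images(p)
#   info = json.loads(processed.js())
#   print(info["seed"], info["sampler"], info["infotexts"][0])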


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        # vectors are nearly parallel: fall back to linear interpolation.
        # (the original had the operands swapped, returning `high` at val=0)
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
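
# Shape sketch (illustrative values): the interpolation runs over dim 1, which
# matches how create_random_tensors below calls it on (C, H, W) latent noise:
#   low = torch.randn(4, 64, 64)
#   high = torch.randn(4, 64, 64)
#   mid = slerp(0.5, low, high)  # halfway along the arc; val=0 -> low, val=1 -> high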


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]
            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)
            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x
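
# Usage sketch (illustrative; mirrors the latent layout used elsewhere in this
# file: opt_C channels at 1/opt_f resolution):
#   x = create_random_tensors([opt_C, 512 // opt_f, 512 // opt_f], seeds=[100, 101])
#   # x.shape == (2, 4, 64, 64); batch item 0 is the same noise a
#   # single-image run with seed 100 would get.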


def fix_seed(p):
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    generation_params = {
        "Steps": p.steps,
        "Sampler": samplers[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
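
# The returned text looks roughly like this (illustrative values, not real output):
#   a photo of a cat
#   Negative prompt: blurry
#   Steps: 50, Sampler: Euler a, CFG scale: 7.0, Seed: 12345, Size: 512x512
# Entries whose value is None are dropped, so optional settings (batch info,
# variation seed, denoising strength, ...) only appear when actually in effect.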


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None

    devices.torch_gc()

    fix_seed(p)

    if p.outpath_samples is not None:
        os.makedirs(p.outpath_samples, exist_ok=True)

    if p.outpath_grids is not None:
        os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = {}

    shared.prompt_styles.apply_styles(p)

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    infotexts = []
    output_images = []

    with torch.no_grad():
        p.init(all_prompts, all_seeds, all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if len(prompts) == 0:
                break

            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
                c = prompt_parser.get_learned_conditioning(prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype)

            if state.interrupted:
                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                infotexts.append(infotext(n, i))
                output_images.append(image)

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                infotexts.insert(0, infotext())
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
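
# Minimal end-to-end sketch (an assumption about usage, not code from this
# file; the web UI builds these objects in modules/txt2img.py):
#   p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model,
#                                        prompt="a photo of a cat",
#                                        steps=20, width=512, height=512)
#   processed = process_images(p)
#   processed.images[processed.index_of_first_image].save("sample.png")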


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None
    firstphase_width = 0
    firstphase_height = 0
    firstphase_width_truncated = 0
    firstphase_height_truncated = 0

    def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.scale_latent = scale_latent
        self.denoising_strength = denoising_strength

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            desired_pixel_count = 512 * 512
            actual_pixel_count = self.width * self.height
            scale = math.sqrt(desired_pixel_count / actual_pixel_count)

            self.firstphase_width = math.ceil(scale * self.width / 64) * 64
            self.firstphase_height = math.ceil(scale * self.height / 64) * 64
            self.firstphase_width_truncated = int(scale * self.width)
            self.firstphase_height_truncated = int(scale * self.height)
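
    # Worked example of the first-phase sizing above (illustrative values):
    # for a 1024x1024 request, scale = sqrt(512*512 / (1024*1024)) = 0.5, so
    # firstphase_width = ceil(0.5 * 1024 / 64) * 64 = 512 (likewise for
    # height). sample() below then renders at 512x512 first, crops the latent,
    # and re-samples img2img-style at the full 1024x1024.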

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)

        truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
        truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f

        samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]

        if self.scale_latent:
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
        else:
            decoded_samples = self.sd_model.decode_first_stage(samples)

            if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
                decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
            else:
                lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

                batch_images = []
                for i, x_sample in enumerate(lowres_samples):
                    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                    x_sample = x_sample.astype(np.uint8)
                    image = Image.fromarray(x_sample)
                    image = images.resize_image(0, image, self.width, self.height)
                    image = np.array(image).astype(np.float32) / 255.0
                    image = np.moveaxis(image, 2, 0)
                    batch_images.append(image)

                decoded_samples = torch.from_numpy(np.array(batch_images))
                decoded_samples = decoded_samples.to(shared.device)
                decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

        shared.state.nextjob()

        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)

        return samples


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples
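
# Usage sketch (assumptions: `init_img` is a PIL image; the web UI builds this
# object in modules/img2img.py):
#   p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model,
#                                        init_images=[init_img],
#                                        prompt="a watercolor cat",
#                                        denoising_strength=0.6)
#   processed = process_images(p)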