processing.py

import json
import math
import os
import sys
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any, Dict, List, Optional

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
import logging


# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8


def setup_color_correction(image):
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, image):
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    return image
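
# Illustrative sketch (not part of the upstream module): setup_color_correction() is meant to be
# called on the img2img source image before sampling, and apply_color_correction() on each
# generated image afterwards. `init_image` and `generated` below are hypothetical PIL RGB images.
#
#   correction = setup_color_correction(init_image)
#   fixed = apply_color_correction(correction, generated)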


def apply_overlay(image, paste_loc, index, overlays):
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = images.resize_image(1, image, w, h)
        base_image.paste(image, (x, y))
        image = base_image

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image
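
# Illustrative note (not upstream code): `paste_loc` is the (x, y, width, height) crop rectangle
# stored in p.paste_to by inpaint-at-full-resolution, so the generated crop is pasted back into a
# canvas the size of the overlay before the masked overlay is composited on top, e.g.:
#
#   result = apply_overlay(generated_crop, (x1, y1, x2 - x1, y2 - y1), i, p.overlay_images)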


def get_correct_sampler(p):
    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
        return sd_samplers.samplers
    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
        return sd_samplers.samplers_for_img2img
    elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
        return sd_samplers.samplers


class StableDiffusionProcessing():
    """
    The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = None
        self.all_prompts = None
        self.all_seeds = None
        self.all_subseeds = None

    def txt2img_image_conditioning(self, x, width=None, height=None):
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            # Still takes up a bit of memory, but no encoder call.
            # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
            return torch.zeros(
                x.shape[0], 5, 1, 1,
                dtype=x.dtype,
                device=x.device
            )

        height = height or self.height
        width = width or self.width

        # The "masked-image" in this case will just be all zeros since the entire image is masked.
        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
        image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))

        # Add the fake full 1s mask to the first dimension.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            return torch.zeros(
                latent_image.shape[0], 5, 1, 1,
                dtype=latent_image.dtype,
                device=latent_image.device
            )

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(source_image.device)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning
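
    # Illustrative note (not upstream code): the tensor returned above is what inpainting
    # checkpoints receive via `c_concat` — channel 0 is the rounded mask and channels 1-4 are the
    # first-stage encoding of the masked source image, which is why the dummy conditioning used for
    # non-inpainting models has 5 channels as well.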

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        raise NotImplementedError()

    def close(self):
        self.sd_model = None
        self.sampler = None


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = sd_samplers.samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1

        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        obj = {
            "prompt": self.prompt,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        # nearly parallel vectors: fall back to plain lerp, keeping the same orientation as the
        # slerp branch below (val=0 returns low, val=1 returns high)
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
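
# Illustrative sketch (not part of the upstream module): this is how subseed variation blends two
# noise tensors in create_random_tensors() below; `seed_noise` and `subseed_noise` are hypothetical.
#
#   seed_noise = devices.randn(100, (4, 64, 64))
#   subseed_noise = devices.randn(200, (4, 64, 64))
#   blended = slerp(0.3, seed_noise, subseed_noise)  # 0.3 = subseed_strength; 0 -> pure seed noise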


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if opts.eta_noise_seed_delta > 0:
                torch.manual_seed(seed + opts.eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x
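
# Illustrative sketch (not part of the upstream module): the latent shape handed to this function is
# [opt_C, height // opt_f, width // opt_f], so a 512x512 image maps to one 4x64x64 noise tensor per seed.
#
#   x = create_random_tensors([opt_C, 512 // opt_f, 512 // opt_f], seeds=[100, 101])
#   # x.shape -> torch.Size([2, 4, 64, 64])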


def decode_first_stage(model, x):
    with devices.autocast(disable=x.dtype == devices.dtype_vae):
        x = model.decode_first_stage(x)

    return x


def get_fixed_seed(seed):
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))

    return seed


def fix_seed(p):
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    generation_params = {
        "Steps": p.steps,
        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
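
# Illustrative example (not part of the upstream module) of the string returned above — one prompt
# line, an optional "Negative prompt:" line, then the comma-separated generation parameters with
# None-valued entries omitted; the values shown here are made up:
#
#   a photo of a cat
#   Negative prompt: blurry
#   Steps: 20, Sampler: Euler a, CFG scale: 7.0, Seed: 12345, Size: 512x512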


def process_images(p: StableDiffusionProcessing) -> Processed:
    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        for k, v in p.override_settings.items():
            opts.data[k] = v  # we don't call onchange for simplicity which makes changing model, hypernet impossible

        res = process_images_inner(p)

    finally:
        for k, v in stored_opts.items():
            opts.data[k] = v

    return res
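
# Illustrative sketch (not part of the upstream module): override_settings lets a caller pin
# selected opts.data entries for a single call; the finally block above restores them afterwards.
# The option name below is one that already appears elsewhere in this file.
#
#   p.override_settings = {"CLIP_stop_at_last_layers": 2}
#   processed = process_images(p)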


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    comments = {}

    shared.prompt_styles.apply_styles(p)

    if type(p.prompt) == list:
        p.all_prompts = p.prompt
    else:
        p.all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if len(prompts) == 0:
                break

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)

            samples_ddim = samples_ddim.to(devices.dtype_vae)
            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()

    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
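
# Illustrative note (not part of the upstream module): p.all_prompts / p.all_seeds / p.all_subseeds
# each hold batch_size * n_iter entries and the loop above slices them per batch, so with
# batch_size=2, n_iter=3 and seed=100 the flat lists are indexed as
#
#   index = n * p.batch_size + i                    # n = batch number, i = position in batch
#   p.all_seeds = [100, 101, 102, 103, 104, 105]    # assuming subseed_strength == 0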


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.firstphase_width = firstphase_width
        self.firstphase_height = firstphase_height
        self.truncate_x = 0
        self.truncate_y = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"

            if self.firstphase_width == 0 or self.firstphase_height == 0:
                desired_pixel_count = 512 * 512
                actual_pixel_count = self.width * self.height
                scale = math.sqrt(desired_pixel_count / actual_pixel_count)
                self.firstphase_width = math.ceil(scale * self.width / 64) * 64
                self.firstphase_height = math.ceil(scale * self.height / 64) * 64
                firstphase_width_truncated = int(scale * self.width)
                firstphase_height_truncated = int(scale * self.height)
            else:
                width_ratio = self.width / self.firstphase_width
                height_ratio = self.height / self.firstphase_height

                if width_ratio > height_ratio:
                    firstphase_width_truncated = self.firstphase_width
                    firstphase_height_truncated = self.firstphase_width * self.height / self.width
                else:
                    firstphase_width_truncated = self.firstphase_height * self.width / self.height
                    firstphase_height_truncated = self.firstphase_height

            self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

        if opts.use_scale_latent_for_hires_fix:
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)
                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

        shared.state.nextjob()

        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        image_conditioning = self.txt2img_image_conditioning(x)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)

        return samples
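
# Illustrative sketch (not part of the upstream module) of how a txt2img caller typically drives
# this class — field names come from the constructors above; a loaded shared.sd_model and
# configured output paths are assumed:
#
#   p = StableDiffusionProcessingTxt2Img(
#       sd_model=shared.sd_model,
#       prompt="a photo of a cat",
#       steps=20,
#       cfg_scale=7.0,
#       width=512,
#       height=512,
#       sampler_index=0,
#   )
#   processed = process_images(p)
#   processed.images[0].save("out.png")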


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples
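
# Illustrative sketch (not part of the upstream module) of driving the img2img path — field names
# come from the constructor above; `init_image` is a hypothetical PIL image and shared.sd_model is
# assumed to be loaded:
#
#   p = StableDiffusionProcessingImg2Img(
#       sd_model=shared.sd_model,
#       init_images=[init_image],
#       denoising_strength=0.6,
#       prompt="a photo of a cat, oil painting",
#       width=512,
#       height=512,
#   )
#   processed = process_images(p)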