# sd_samplers.py
  1. from collections import namedtuple, deque
  2. import numpy as np
  3. from math import floor
  4. import torch
  5. import tqdm
  6. from PIL import Image
  7. import inspect
  8. import k_diffusion.sampling
  9. import torchsde._brownian.brownian_interval
  10. import ldm.models.diffusion.ddim
  11. import ldm.models.diffusion.plms
  12. from modules import prompt_parser, devices, processing, images, sd_vae_approx
  13. from modules.shared import opts, cmd_opts, state
  14. import modules.shared as shared
  15. from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
# One record per selectable sampler: UI name, factory taking the sd model,
# lowercase CLI/API aliases, and extra options (scheduler tweaks etc.).
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

# Table of k-diffusion samplers:
# (UI label, function name in k_diffusion.sampling, aliases, options).
# 'scheduler': 'karras' selects the Karras sigma schedule in get_sigmas();
# 'discard_next_to_last_sigma' drops the next-to-last sigma for the DPM2
# family (see KDiffusionSampler.get_sigmas).
samplers_k_diffusion = [
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
    ('Euler', 'sample_euler', ['k_euler'], {}),
    ('LMS', 'sample_lms', ['k_lms'], {}),
    ('Heun', 'sample_heun', ['k_heun'], {}),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
]

# Build SamplerData entries, skipping entries the installed k-diffusion lacks.
# funcname is bound as a lambda default on purpose: a plain closure would
# late-bind and every constructor would end up using the last funcname.
samplers_data_k_diffusion = [
    SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
    for label, funcname, aliases, options in samplers_k_diffusion
    if hasattr(k_diffusion.sampling, funcname)
]
# Every selectable sampler: the k-diffusion ones plus the original LDM
# DDIM/PLMS implementations wrapped in VanillaStableDiffusionSampler.
all_samplers = [
    *samplers_data_k_diffusion,
    SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
    SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
all_samplers_map = {x.name: x for x in all_samplers}

# Populated by set_samplers() below; samplers_map maps lowercase names and
# aliases to the canonical sampler name.
samplers = []
samplers_for_img2img = []
samplers_map = {}
  50. def create_sampler(name, model):
  51. if name is not None:
  52. config = all_samplers_map.get(name, None)
  53. else:
  54. config = all_samplers[0]
  55. assert config is not None, f'bad sampler name: {name}'
  56. sampler = config.constructor(model)
  57. sampler.config = config
  58. return sampler
  59. def set_samplers():
  60. global samplers, samplers_for_img2img
  61. hidden = set(opts.hide_samplers)
  62. hidden_img2img = set(opts.hide_samplers + ['PLMS'])
  63. samplers = [x for x in all_samplers if x.name not in hidden]
  64. samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
  65. samplers_map.clear()
  66. for sampler in all_samplers:
  67. samplers_map[sampler.name.lower()] = sampler.name
  68. for alias in sampler.aliases:
  69. samplers_map[alias.lower()] = sampler.name
  70. set_samplers()
# Extra keyword arguments copied from the processing object into the sampler
# call when the k-diffusion function's signature accepts them
# (see KDiffusionSampler.initialize).
sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
  76. def setup_img2img_steps(p, steps=None):
  77. if opts.img2img_fix_steps or steps is not None:
  78. steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
  79. t_enc = p.steps - 1
  80. else:
  81. steps = p.steps
  82. t_enc = int(min(p.denoising_strength, 0.999) * steps)
  83. return steps, t_enc
  84. approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
  85. def single_sample_to_image(sample, approximation=None):
  86. if approximation is None:
  87. approximation = approximation_indexes.get(opts.show_progress_type, 0)
  88. if approximation == 2:
  89. x_sample = sd_vae_approx.cheap_approximation(sample)
  90. elif approximation == 1:
  91. x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
  92. else:
  93. x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
  94. x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
  95. x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
  96. x_sample = x_sample.astype(np.uint8)
  97. return Image.fromarray(x_sample)
  98. def sample_to_image(samples, index=0, approximation=None):
  99. return single_sample_to_image(samples[index], approximation)
  100. def samples_to_image_grid(samples, approximation=None):
  101. return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
  102. def store_latent(decoded):
  103. state.current_latent = decoded
  104. if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
  105. if not shared.parallel_processing_allowed:
  106. shared.state.current_image = sample_to_image(decoded)
class InterruptedException(BaseException):
    # Raised from inside sampler step callbacks to abort sampling early.
    # Derives from BaseException so that `except Exception` handlers inside
    # the wrapped third-party samplers cannot swallow it.
    pass
class VanillaStableDiffusionSampler:
    """Adapter around the original LDM DDIM/PLMS samplers.

    Gives them the same interface as KDiffusionSampler (sample,
    sample_img2img, number_of_needed_noises, stop_at, ...) by hooking the
    per-step p_sample_* function to add interruption, prompt scheduling,
    inpainting masks and live previews.
    """

    def __init__(self, constructor, sd_model):
        self.sampler = constructor(sd_model)
        # PLMS samplers expose p_sample_plms instead of p_sample_ddim.
        self.is_plms = hasattr(self.sampler, 'p_sample_plms')
        # The unhooked per-step function; p_sample_ddim_hook wraps it.
        self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
        self.mask = None            # inpainting mask (set in initialize)
        self.nmask = None           # inverse of the mask
        self.init_latent = None
        self.sampler_noises = None
        self.step = 0               # current sampling step, used for prompt scheduling
        self.stop_at = None         # raise InterruptedException past this step
        self.eta = None
        self.default_eta = 0.0
        self.config = None          # SamplerData attached by create_sampler
        self.last_latent = None     # latest latent, returned on interruption
        self.conditioning_key = sd_model.model.conditioning_key

    def number_of_needed_noises(self, p):
        # The vanilla samplers generate their own noise; none is pre-supplied.
        return 0

    def launch_sampling(self, steps, func):
        """Run func() with progress state initialized; on interruption,
        return the last latent produced instead of propagating."""
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except InterruptedException:
            return self.last_latent

    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
        """Replacement for the sampler's p_sample_ddim/p_sample_plms: applies
        prompt scheduling, uncond length matching and masking, then defers to
        the original step function."""
        if state.interrupted or state.skipped:
            raise InterruptedException

        if self.stop_at is not None and self.step > self.stop_at:
            raise InterruptedException

        # Have to unwrap the inpainting conditioning here to perform pre-processing
        image_conditioning = None
        if isinstance(cond, dict):
            image_conditioning = cond["c_concat"][0]
            cond = cond["c_crossattn"][0]
            unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]

        # Resolve the scheduled prompt conditioning for the current step.
        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)

        assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
        cond = tensor

        # for DDIM, shapes must match, we can't just process cond and uncond independently;
        # filling unconditional_conditioning with repeats of the last vector to match length is
        # not 100% correct but should work well enough
        if unconditional_conditioning.shape[1] < cond.shape[1]:
            last_vector = unconditional_conditioning[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
            unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
        elif unconditional_conditioning.shape[1] > cond.shape[1]:
            unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]

        # Inpainting: re-noise the original image to this timestep and keep
        # the masked region from it.
        if self.mask is not None:
            img_orig = self.sampler.model.q_sample(self.init_latent, ts)
            x_dec = img_orig * self.mask + self.nmask * x_dec

        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
        # Note that they need to be lists because it just concatenates them later.
        if image_conditioning is not None:
            cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

        res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)

        # res[1] is the predicted x0; blend with the mask if inpainting.
        if self.mask is not None:
            self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
        else:
            self.last_latent = res[1]

        store_latent(self.last_latent)

        self.step += 1
        state.sampling_step = self.step
        shared.total_tqdm.update()

        return res

    def initialize(self, p):
        """Install the step hook and copy eta/mask settings from p."""
        self.eta = p.eta if p.eta is not None else opts.eta_ddim

        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
            if hasattr(self.sampler, fieldname):
                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)

        self.mask = p.mask if hasattr(p, 'mask') else None
        self.nmask = p.nmask if hasattr(p, 'nmask') else None

    def adjust_steps_if_invalid(self, p, num_steps):
        """Bump the step count by one for counts that are problematic for
        DDIM uniform discretization / PLMS.

        NOTE(review): presumably this works around an out-of-range timestep
        in ldm's uniform schedule when num_steps divides 1000 — confirm
        against ldm's make_ddim_timesteps.
        """
        if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
            valid_step = 999 / (1000 // num_steps)
            if valid_step == floor(valid_step):
                return int(valid_step) + 1

        return num_steps

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """img2img: stochastically encode x to timestep t_enc, then decode."""
        steps, t_enc = setup_img2img_steps(p, steps)
        steps = self.adjust_steps_if_invalid(p, steps)
        self.initialize(p)

        self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)

        # Add noise corresponding to t_enc steps to the input latent.
        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)

        self.init_latent = x
        self.last_latent = x
        self.step = 0

        # Wrap the conditioning models with additional image conditioning for inpainting model
        if image_conditioning is not None:
            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

        samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """txt2img: run the full DDIM/PLMS sampling loop from noise x."""
        self.initialize(p)

        self.init_latent = None
        self.last_latent = x
        self.step = 0

        steps = self.adjust_steps_if_invalid(p, steps or p.steps)

        # Wrap the conditioning models with additional image conditioning for inpainting model
        # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
        if image_conditioning is not None:
            conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
            unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}

        samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

        return samples_ddim
class CFGDenoiser(torch.nn.Module):
    """Wraps a k-diffusion denoiser to apply classifier-free guidance.

    forward() expands the batch with one copy per AND-composed conditioning
    plus the unconditional inputs, runs the inner model (optionally in
    chunks), and recombines the outputs via combine_denoised.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        self.mask = None          # optional inpainting mask, blended in forward()
        self.nmask = None         # inverse of the mask
        self.init_latent = None   # original latent used with the mask
        self.step = 0             # sampling step index used for prompt scheduling

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """CFG combination: start from the unconditional predictions (the
        last uncond.shape[0] rows of x_out) and add each weighted
        (cond - uncond) delta scaled by cond_scale."""
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
        if state.interrupted or state.skipped:
            raise InterruptedException

        # Resolve the scheduled prompt conditioning for the current step.
        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        batch_size = len(conds_list)
        # Number of AND-composed conds for each sample in the batch.
        repeats = [len(conds_list[i]) for i in range(batch_size)]

        # Repeat each input once per composed cond, then append the originals
        # for the unconditional pass.
        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])

        # Let script callbacks inspect/modify the denoiser inputs.
        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma

        if tensor.shape[1] == uncond.shape[1]:
            # cond and uncond have the same token length and can share batches.
            cond_in = torch.cat([tensor, uncond])

            if shared.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
            else:
                # Run in batch_size chunks to reduce memory use.
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
        else:
            # Different token lengths: run the conds in chunks first, then the
            # uncond rows separately at the end.
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})

            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})

        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        # Inpainting: keep the masked region from the original latent.
        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.step += 1

        return denoised
  270. class TorchHijack:
  271. def __init__(self, sampler_noises):
  272. # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
  273. # implementation.
  274. self.sampler_noises = deque(sampler_noises)
  275. def __getattr__(self, item):
  276. if item == 'randn_like':
  277. return self.randn_like
  278. if hasattr(torch, item):
  279. return getattr(torch, item)
  280. raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
  281. def randn_like(self, x):
  282. if self.sampler_noises:
  283. noise = self.sampler_noises.popleft()
  284. if noise.shape == x.shape:
  285. return noise
  286. if x.device.type == 'mps':
  287. return torch.randn_like(x, device=devices.cpu).to(x.device)
  288. else:
  289. return torch.randn_like(x)
  290. # MPS fix for randn in torchsde
  291. def torchsde_randn(size, dtype, device, seed):
  292. if device.type == 'mps':
  293. generator = torch.Generator(devices.cpu).manual_seed(int(seed))
  294. return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
  295. else:
  296. generator = torch.Generator(device).manual_seed(int(seed))
  297. return torch.randn(size, dtype=dtype, device=device, generator=generator)
  298. torchsde._brownian.brownian_interval._randn = torchsde_randn
  299. class KDiffusionSampler:
  300. def __init__(self, funcname, sd_model):
  301. denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
  302. self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
  303. self.funcname = funcname
  304. self.func = getattr(k_diffusion.sampling, self.funcname)
  305. self.extra_params = sampler_extra_params.get(funcname, [])
  306. self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
  307. self.sampler_noises = None
  308. self.stop_at = None
  309. self.eta = None
  310. self.default_eta = 1.0
  311. self.config = None
  312. self.last_latent = None
  313. self.conditioning_key = sd_model.model.conditioning_key
  314. def callback_state(self, d):
  315. step = d['i']
  316. latent = d["denoised"]
  317. store_latent(latent)
  318. self.last_latent = latent
  319. if self.stop_at is not None and step > self.stop_at:
  320. raise InterruptedException
  321. state.sampling_step = step
  322. shared.total_tqdm.update()
  323. def launch_sampling(self, steps, func):
  324. state.sampling_steps = steps
  325. state.sampling_step = 0
  326. try:
  327. return func()
  328. except InterruptedException:
  329. return self.last_latent
  330. def number_of_needed_noises(self, p):
  331. return p.steps
  332. def initialize(self, p):
  333. self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
  334. self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
  335. self.model_wrap.step = 0
  336. self.eta = p.eta or opts.eta_ancestral
  337. k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
  338. extra_params_kwargs = {}
  339. for param_name in self.extra_params:
  340. if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
  341. extra_params_kwargs[param_name] = getattr(p, param_name)
  342. if 'eta' in inspect.signature(self.func).parameters:
  343. extra_params_kwargs['eta'] = self.eta
  344. return extra_params_kwargs
  345. def get_sigmas(self, p, steps):
  346. if p.sampler_noise_scheduler_override:
  347. sigmas = p.sampler_noise_scheduler_override(steps)
  348. elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
  349. sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
  350. else:
  351. sigmas = self.model_wrap.get_sigmas(steps)
  352. if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
  353. sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
  354. return sigmas
  355. def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
  356. steps, t_enc = setup_img2img_steps(p, steps)
  357. sigmas = self.get_sigmas(p, steps)
  358. sigma_sched = sigmas[steps - t_enc - 1:]
  359. xi = x + noise * sigma_sched[0]
  360. extra_params_kwargs = self.initialize(p)
  361. if 'sigma_min' in inspect.signature(self.func).parameters:
  362. ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
  363. extra_params_kwargs['sigma_min'] = sigma_sched[-2]
  364. if 'sigma_max' in inspect.signature(self.func).parameters:
  365. extra_params_kwargs['sigma_max'] = sigma_sched[0]
  366. if 'n' in inspect.signature(self.func).parameters:
  367. extra_params_kwargs['n'] = len(sigma_sched) - 1
  368. if 'sigma_sched' in inspect.signature(self.func).parameters:
  369. extra_params_kwargs['sigma_sched'] = sigma_sched
  370. if 'sigmas' in inspect.signature(self.func).parameters:
  371. extra_params_kwargs['sigmas'] = sigma_sched
  372. self.model_wrap_cfg.init_latent = x
  373. self.last_latent = x
  374. samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
  375. 'cond': conditioning,
  376. 'image_cond': image_conditioning,
  377. 'uncond': unconditional_conditioning,
  378. 'cond_scale': p.cfg_scale
  379. }, disable=False, callback=self.callback_state, **extra_params_kwargs))
  380. return samples
  381. def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
  382. steps = steps or p.steps
  383. sigmas = self.get_sigmas(p, steps)
  384. x = x * sigmas[0]
  385. extra_params_kwargs = self.initialize(p)
  386. if 'sigma_min' in inspect.signature(self.func).parameters:
  387. extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
  388. extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
  389. if 'n' in inspect.signature(self.func).parameters:
  390. extra_params_kwargs['n'] = steps
  391. else:
  392. extra_params_kwargs['sigmas'] = sigmas
  393. self.last_latent = x
  394. samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
  395. 'cond': conditioning,
  396. 'image_cond': image_conditioning,
  397. 'uncond': unconditional_conditioning,
  398. 'cond_scale': p.cfg_scale
  399. }, disable=False, callback=self.callback_state, **extra_params_kwargs))
  400. return samples