sd_samplers_common.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. import inspect
  2. from collections import namedtuple
  3. import numpy as np
  4. import torch
  5. from PIL import Image
  6. from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
  7. from modules.shared import opts, state
  8. import k_diffusion.sampling
  SamplerDataTuple = namedtuple('SamplerData', 'name constructor aliases options')


  class SamplerData(SamplerDataTuple):
      """Description of one sampler: its name, constructor, aliases and options dict."""

      def total_steps(self, steps):
          """Return the step count to report for ``steps`` sampling steps.

          Samplers whose options set ``second_order`` are counted twice per step;
          every other sampler reports ``steps`` unchanged.
          """
          if not self.options.get("second_order", False):
              return steps
          return steps * 2
  def setup_img2img_steps(p, steps=None):
      """Return ``(steps, t_enc)`` for an img2img job.

      When ``opts.img2img_fix_steps`` is on, or an explicit ``steps`` is passed, the
      requested count (``steps`` or ``p.steps``) is divided by the denoising strength
      (capped at 0.999). The result is 0 for a non-positive strength, and ``t_enc``
      is the requested count minus one. Otherwise ``p.steps`` is used unchanged and
      ``t_enc`` is that count scaled by the capped strength.
      """
      use_requested = opts.img2img_fix_steps or steps is not None
      if not use_requested:
          total = p.steps
          return total, int(min(p.denoising_strength, 0.999) * total)

      requested = steps or p.steps
      if p.denoising_strength > 0:
          total = int(requested / min(p.denoising_strength, 0.999))
      else:
          total = 0
      return total, requested - 1
  # Option values naming a latent decode/encode method -> the integer code the
  # conversion helpers in this module branch on.
  approximation_indexes = dict(zip(("Full", "Approx NN", "Approx cheap", "TAESD"), range(4)))
  def samples_to_images_tensor(sample, approximation=None, model=None):
      """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1].

      ``approximation`` is a code from ``approximation_indexes``: 0 = full decode via
      ``model.decode_first_stage`` (default model ``shared.sd_model``), 1 = the
      ``sd_vae_approx`` NN model, 2 = ``sd_vae_approx.cheap_approximation``,
      3 = the TAESD decoder. It is taken from ``opts.show_progress_type`` when it is
      None, or when the job was interrupted and ``opts.live_preview_fast_interrupt``
      is set. Code 0 is downgraded to 1 when lowvram is enabled for
      ``shared.sd_model`` and ``live_preview_allow_lowvram_full`` is off.
      """
      if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
          approximation = approximation_indexes.get(opts.show_progress_type, 0)

      from modules import lowvram
      full_decode_blocked = (
          approximation == 0
          and lowvram.is_enabled(shared.sd_model)
          and not shared.opts.live_preview_allow_lowvram_full
      )
      if full_decode_blocked:
          approximation = 1

      if approximation == 2:
          return sd_vae_approx.cheap_approximation(sample)
      if approximation == 1:
          return sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
      if approximation == 3:
          decoded = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
          # Rescale from [0, 1] to the [-1, 1] range the other paths return.
          return decoded * 2 - 1

      decoder = shared.sd_model if model is None else model
      with devices.without_autocast():  # fixes an issue with unstable VAEs that are flaky even in fp32
          return decoder.decode_first_stage(sample.to(decoder.first_stage_model.dtype))
  def single_sample_to_image(sample, approximation=None):
      """Decode a single latent (no batch dimension) into an 8-bit PIL image."""
      decoded = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0]
      unit_range = torch.clamp(decoded * 0.5 + 0.5, min=0.0, max=1.0)  # [-1, 1] -> [0, 1]
      channels_last = np.moveaxis(unit_range.cpu().numpy(), 0, 2)  # move axis 0 to the end for PIL
      pixels = (255. * channels_last).astype(np.uint8)
      return Image.fromarray(pixels)
  def decode_first_stage(model, x):
      """Decode latents ``x`` with ``model``, using the method set by ``opts.sd_vae_decode_method``."""
      latents = x.to(devices.dtype_vae)
      method = approximation_indexes.get(opts.sd_vae_decode_method, 0)
      return samples_to_images_tensor(latents, approximation=method, model=model)
  def sample_to_image(samples, index=0, approximation=None):
      """Decode ``samples[index]`` into a PIL image."""
      chosen = samples[index]
      return single_sample_to_image(chosen, approximation=approximation)
  def samples_to_image_grid(samples, approximation=None):
      """Decode every latent in ``samples`` and combine them with ``images.image_grid``."""
      decoded = [single_sample_to_image(one, approximation) for one in samples]
      return images.image_grid(decoded)
  def images_tensor_to_samples(image, approximation=None, model=None):
      '''image[0, 1] -> latent

      ``approximation`` is a code from ``approximation_indexes`` (default taken from
      ``opts.sd_vae_encode_method``). Code 3 encodes with the TAESD encoder; any other
      value uses the first-stage model of ``model`` (default ``shared.sd_model``),
      encoding batches of more than one image one image at a time.
      '''
      if approximation is None:
          approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)

      if approximation == 3:
          return sd_vae_taesd.encoder_model()(image.to(devices.device, devices.dtype))

      vae_model = shared.sd_model if model is None else model
      vae_model.first_stage_model.to(devices.dtype_vae)
      # Rescale [0, 1] input to [-1, 1] before first-stage encoding.
      scaled = image.to(shared.device, dtype=devices.dtype_vae) * 2 - 1

      def encode(batch):
          return vae_model.get_first_stage_encoding(vae_model.encode_first_stage(batch))

      if len(scaled) > 1:
          return torch.stack([encode(img.unsqueeze(0))[0] for img in scaled])
      return encode(scaled)
  def store_latent(decoded):
      """Save ``decoded`` as ``state.current_latent`` and refresh the live preview when due.

      A preview is decoded every ``opts.show_progress_every_n_steps`` sampling steps
      while live previews are enabled, and only when
      ``shared.parallel_processing_allowed`` is false.
      """
      state.current_latent = decoded

      preview_due = (
          opts.live_previews_enable
          and opts.show_progress_every_n_steps > 0
          and shared.state.sampling_step % opts.show_progress_every_n_steps == 0
      )
      if preview_due and not shared.parallel_processing_allowed:
          shared.state.assign_current_image(sample_to_image(decoded))
  def is_sampler_using_eta_noise_seed_delta(p):
      """Return whether the configured sampler will use eta noise seed delta for image creation.

      The effective eta is ``p.eta``, then ``p.sampler.eta``, then the sampler config
      default (0 when its options set ``default_eta_is_0``, otherwise 1.0). An eta of 0
      means ENSD is unused. Otherwise the answer is the config's ``uses_ensd`` option,
      or False when ``p.sampler_name`` has no sampler config.
      """
      sampler_config = sd_samplers.find_sampler_config(p.sampler_name)

      eta = p.eta

      if eta is None and p.sampler is not None:
          eta = p.sampler.eta

      if eta is None and sampler_config is not None:
          eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0

      if eta == 0:
          return False

      # find_sampler_config may return None (see the check above); without a config
      # there is no "uses_ensd" option to consult.
      if sampler_config is None:
          return False

      return sampler_config.options.get("uses_ensd", False)
  class InterruptedException(BaseException):
      """Signals that sampling should stop early.

      ``Sampler.callback_state`` raises it once the step passes ``stop_at``, and
      ``Sampler.launch_sampling`` catches it and returns the last latent. It derives
      from BaseException, presumably so generic ``except Exception`` handlers do not
      swallow it.
      """
  def replace_torchsde_browinan():
      """Replace torchsde's ``_brownian.brownian_interval._randn`` with a seeded generator.

      The replacement draws ``devices.randn_local(seed, size)`` and then moves the
      result to the requested device and dtype.
      """
      import torchsde._brownian.brownian_interval as brownian_interval

      def torchsde_randn(size, dtype, device, seed):
          noise = devices.randn_local(seed, size)
          return noise.to(device=device, dtype=dtype)

      brownian_interval._randn = torchsde_randn


  # Applied once, when this module is imported.
  replace_torchsde_browinan()
  def apply_refiner(cfg_denoiser, sigma=None):
      """Switch the running job to its refiner checkpoint once the switch point is reached.

      Progress is the fraction of completed sampling steps when
      ``opts.refiner_switch_by_sample_steps`` is set or ``sigma`` is None. Otherwise it
      comes from the position of ``max(sigma)`` on the denoiser's sigma schedule.
      Returns False, changing nothing, when:
        - the switch point is not yet reached;
        - no refiner is set;
        - the refiner is already loaded;
        - the hires-fix refiner-pass option excludes the current pass.
      Otherwise it records infotext, loads the refiner weights, rebuilds the conds and
      the inner model, and returns True.
      """
      p = cfg_denoiser.p

      if opts.refiner_switch_by_sample_steps or sigma is None:
          completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
          p.extra_generation_params["Refiner switch by sampling steps"] = True
      else:
          # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
          peak_sigma = torch.max(sigma)
          try:
              timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - peak_sigma))
          except AttributeError:  # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
              timestep = peak_sigma.to(dtype=int)
          completed_ratio = (999 - timestep) / 1000

      refiner_switch_at = p.refiner_switch_at
      refiner_checkpoint_info = p.refiner_checkpoint_info

      if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
          return False

      if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
          return False

      if getattr(p, "enable_hr", False):
          is_second_pass = p.is_hr_pass
          refiner_pass = opts.hires_fix_refiner_pass

          wrong_pass = (
              (refiner_pass == "first pass" and is_second_pass)
              or (refiner_pass == "second pass" and not is_second_pass)
          )
          if wrong_pass:
              return False

          if refiner_pass != "second pass":
              p.extra_generation_params['Hires refiner'] = refiner_pass

      p.extra_generation_params.update({
          'Refiner': refiner_checkpoint_info.short_title,
          'Refiner switch at': refiner_switch_at,
      })

      with sd_models.SkipWritingToConfig():
          sd_models.reload_model_weights(info=refiner_checkpoint_info)

      devices.torch_gc()
      p.setup_conds()
      cfg_denoiser.update_inner_model()
      return True
  class TorchHijack:
      """Stand-in for the ``torch`` module used by k-diffusion's sampling code.

      k-diffusion takes a random-sampler argument for most samplers but not all, so
      every ``torch.randn_like`` call is routed through ``p.rng`` by swapping in this
      object. This keeps images generated in a batch identical to images generated
      individually. Every other attribute is forwarded to the real ``torch`` module.
      """

      def __init__(self, p):
          self.rng = p.rng

      def __getattr__(self, item):
          # Only reached when normal lookup fails; randn_like is a real method, so it
          # is normally found before this point.
          if item == 'randn_like':
              return self.randn_like

          if not hasattr(torch, item):
              raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

          return getattr(torch, item)

      def randn_like(self, x):
          """Return the next noise tensor from ``self.rng``.

          ``x`` is not used; presumably the rng already yields noise of the right
          shape (TODO confirm).
          """
          return self.rng.next()
  class Sampler:
      """Base class for samplers.

      Holds per-job sampler state (step limit, eta, sigma settings) and shared
      plumbing: the progress callback, launching the sampling function, building
      its extra keyword arguments, and recording infotext. ``sample`` and
      ``sample_img2img`` must be implemented by subclasses.
      """

      def __init__(self, funcname):
          self.funcname = funcname
          self.func = funcname
          self.extra_params = []
          self.sampler_noises = None
          self.stop_at = None
          self.eta = None
          self.config: SamplerData = None  # set by the function calling the constructor
          self.last_latent = None
          self.s_min_uncond = None

          # Baseline sigma settings; initialize() only overrides/records values that differ.
          self.s_churn = 0.0
          self.s_tmin = 0.0
          self.s_tmax = float('inf')
          self.s_noise = 1.0

          self.eta_option_field = 'eta_ancestral'
          self.eta_infotext_field = 'Eta'
          self.eta_default = 1.0

          self.conditioning_key = shared.sd_model.model.conditioning_key

          self.p = None
          self.model_wrap_cfg = None
          self.sampler_extra_args = None
          self.options = {}

      def callback_state(self, d):
          """Per-step progress callback; ``d['i']`` is the current step index.

          Raises InterruptedException once the step exceeds ``self.stop_at``.
          """
          step = d['i']

          if self.stop_at is not None and step > self.stop_at:
              raise InterruptedException

          state.sampling_step = step
          shared.total_tqdm.update()

      def launch_sampling(self, steps, func):
          """Reset step counters for a ``steps``-step run and return ``func()``.

          A RecursionError (noted as possible with rho > 5 on a polyexponential
          scheduler) or an InterruptedException ends sampling early and returns
          ``self.last_latent`` instead.
          """
          self.model_wrap_cfg.steps = steps
          self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
          state.sampling_steps = steps
          state.sampling_step = 0

          try:
              return func()
          except RecursionError:
              print(
                  'Encountered RecursionError during sampling, returning last latent. '
                  'rho >5 with a polyexponential scheduler may cause this error. '
                  'You should try to use a smaller rho value instead.'
              )
              return self.last_latent
          except InterruptedException:
              return self.last_latent

      def number_of_needed_noises(self, p):
          """Return how many noise tensors the job needs (one per step of ``p``)."""
          return p.steps

      def initialize(self, p) -> dict:
          """Bind this sampler to job ``p`` and return extra kwargs for ``self.func``.

          Configures ``self.model_wrap_cfg`` for ``p`` and patches k-diffusion's
          ``torch`` with ``TorchHijack(p)``. The returned dict contains:
            - each name in ``self.extra_params`` that ``p`` has and ``self.func``
              accepts;
            - ``eta``, when ``self.func`` takes it;
            - sigma settings (read from ``opts``, falling back to ``p``), but only
              when they differ from this sampler's defaults. Those values are
              also written back to ``p`` and its infotext.
          """
          self.p = p
          self.model_wrap_cfg.p = p
          self.model_wrap_cfg.mask = getattr(p, 'mask', None)
          self.model_wrap_cfg.nmask = getattr(p, 'nmask', None)
          self.model_wrap_cfg.step = 0
          self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
          self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
          self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

          k_diffusion.sampling.torch = TorchHijack(p)

          # The sampling function's signature is fixed for this call: inspect it once
          # instead of once per extra parameter.
          func_params = inspect.signature(self.func).parameters

          extra_params_kwargs = {
              param_name: getattr(p, param_name)
              for param_name in self.extra_params
              if hasattr(p, param_name) and param_name in func_params
          }

          if 'eta' in func_params:
              if self.eta != self.eta_default:
                  p.extra_generation_params[self.eta_infotext_field] = self.eta

              extra_params_kwargs['eta'] = self.eta

          if len(self.extra_params) > 0:
              self._apply_sigma_overrides(p, extra_params_kwargs)

          return extra_params_kwargs

      def _apply_sigma_overrides(self, p, extra_params_kwargs):
          """Apply sigma settings from ``opts`` (falling back to ``p``) to kwargs, ``p`` and infotext.

          A setting is applied only when its name is already in ``extra_params_kwargs``
          and its value differs from this sampler's default attribute of the same name.
          """
          # Read every value before applying any of them.
          overrides = (
              ('s_churn', 'Sigma churn', getattr(opts, 's_churn', p.s_churn)),
              ('s_tmin', 'Sigma tmin', getattr(opts, 's_tmin', p.s_tmin)),
              ('s_tmax', 'Sigma tmax', getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax),  # 0 = inf
              ('s_noise', 'Sigma noise', getattr(opts, 's_noise', p.s_noise)),
          )

          for name, infotext_key, value in overrides:
              if name in extra_params_kwargs and value != getattr(self, name):
                  extra_params_kwargs[name] = value
                  setattr(p, name, value)
                  p.extra_generation_params[infotext_key] = value

      def create_noise_sampler(self, x, sigmas, p):
          """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes

          Returns None when ``opts.no_dpmpp_sde_batch_determinism`` is set. Otherwise
          it returns a BrownianTreeNoiseSampler seeded with this iteration's slice of
          ``p.all_seeds``.
          """
          if shared.opts.no_dpmpp_sde_batch_determinism:
              return None

          from k_diffusion.sampling import BrownianTreeNoiseSampler
          sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
          current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
          return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

      def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
          """txt2img sampling; must be implemented by subclasses."""
          raise NotImplementedError()

      def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
          """img2img sampling; must be implemented by subclasses."""
          raise NotImplementedError()

      def add_infotext(self, p):
          """Record in ``p``'s infotext whether conds were padded (and by which variant)."""
          if self.model_wrap_cfg.padded_cond_uncond:
              p.extra_generation_params["Pad conds"] = True

          if self.model_wrap_cfg.padded_cond_uncond_v0:
              p.extra_generation_params["Pad conds v0"] = True