sd_samplers_timesteps_impl.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import torch
  2. import tqdm
  3. import k_diffusion.sampling
  4. import numpy as np
  5. from modules import shared
  6. from modules.models.diffusion.uni_pc import uni_pc
  7. from modules.torch_utils import float64
  8. @torch.no_grad()
  9. def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
  10. alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
  11. alphas = alphas_cumprod[timesteps]
  12. alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
  13. sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
  14. sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
  15. extra_args = {} if extra_args is None else extra_args
  16. s_in = x.new_ones((x.shape[0]))
  17. s_x = x.new_ones((x.shape[0], 1, 1, 1))
  18. for i in tqdm.trange(len(timesteps) - 1, disable=disable):
  19. index = len(timesteps) - 1 - i
  20. e_t = model(x, timesteps[index].item() * s_in, **extra_args)
  21. a_t = alphas[index].item() * s_x
  22. a_prev = alphas_prev[index].item() * s_x
  23. sigma_t = sigmas[index].item() * s_x
  24. sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
  25. pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
  26. dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
  27. noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
  28. x = a_prev.sqrt() * pred_x0 + dir_xt + noise
  29. if callback is not None:
  30. callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
  31. return x
  32. @torch.no_grad()
  33. def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
  34. """ Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
  35. Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
  36. The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
  37. """
  38. alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
  39. alphas = alphas_cumprod[timesteps]
  40. alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
  41. sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
  42. sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
  43. model.cond_scale_miltiplier = 1 / 12.5
  44. model.need_last_noise_uncond = True
  45. extra_args = {} if extra_args is None else extra_args
  46. s_in = x.new_ones((x.shape[0]))
  47. s_x = x.new_ones((x.shape[0], 1, 1, 1))
  48. for i in tqdm.trange(len(timesteps) - 1, disable=disable):
  49. index = len(timesteps) - 1 - i
  50. e_t = model(x, timesteps[index].item() * s_in, **extra_args)
  51. last_noise_uncond = model.last_noise_uncond
  52. a_t = alphas[index].item() * s_x
  53. a_prev = alphas_prev[index].item() * s_x
  54. sigma_t = sigmas[index].item() * s_x
  55. sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
  56. pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
  57. dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
  58. noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
  59. x = a_prev.sqrt() * pred_x0 + dir_xt + noise
  60. if callback is not None:
  61. callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
  62. return x
  63. @torch.no_grad()
  64. def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
  65. alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
  66. alphas = alphas_cumprod[timesteps]
  67. alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
  68. sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
  69. extra_args = {} if extra_args is None else extra_args
  70. s_in = x.new_ones([x.shape[0]])
  71. s_x = x.new_ones((x.shape[0], 1, 1, 1))
  72. old_eps = []
  73. def get_x_prev_and_pred_x0(e_t, index):
  74. # select parameters corresponding to the currently considered timestep
  75. a_t = alphas[index].item() * s_x
  76. a_prev = alphas_prev[index].item() * s_x
  77. sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
  78. # current prediction for x_0
  79. pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
  80. # direction pointing to x_t
  81. dir_xt = (1. - a_prev).sqrt() * e_t
  82. x_prev = a_prev.sqrt() * pred_x0 + dir_xt
  83. return x_prev, pred_x0
  84. for i in tqdm.trange(len(timesteps) - 1, disable=disable):
  85. index = len(timesteps) - 1 - i
  86. ts = timesteps[index].item() * s_in
  87. t_next = timesteps[max(index - 1, 0)].item() * s_in
  88. e_t = model(x, ts, **extra_args)
  89. if len(old_eps) == 0:
  90. # Pseudo Improved Euler (2nd order)
  91. x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
  92. e_t_next = model(x_prev, t_next, **extra_args)
  93. e_t_prime = (e_t + e_t_next) / 2
  94. elif len(old_eps) == 1:
  95. # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
  96. e_t_prime = (3 * e_t - old_eps[-1]) / 2
  97. elif len(old_eps) == 2:
  98. # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
  99. e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
  100. else:
  101. # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
  102. e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
  103. x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
  104. old_eps.append(e_t)
  105. if len(old_eps) >= 4:
  106. old_eps.pop(0)
  107. x = x_prev
  108. if callback is not None:
  109. callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
  110. return x
  111. class UniPCCFG(uni_pc.UniPC):
  112. def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
  113. super().__init__(None, *args, **kwargs)
  114. def after_update(x, model_x):
  115. callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
  116. self.index += 1
  117. self.cfg_model = cfg_model
  118. self.extra_args = extra_args
  119. self.callback = callback
  120. self.index = 0
  121. self.after_update = after_update
  122. def get_model_input_time(self, t_continuous):
  123. return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
  124. def model(self, x, t):
  125. t_input = self.get_model_input_time(t)
  126. res = self.cfg_model(x, t_input, **self.extra_args)
  127. return res
  128. def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
  129. alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
  130. ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
  131. t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
  132. unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
  133. x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
  134. return x