sd_samplers_timesteps_impl.py

import torch
import tqdm
import k_diffusion.sampling
import numpy as np

from modules import shared
from modules.models.diffusion.uni_pc import uni_pc
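
# DDIM sampling over an explicit list of discrete timesteps.
# `model` is expected to be a denoiser wrapper whose inner_model.inner_model exposes the
# LDM alphas_cumprod schedule; `timesteps` is a 1-D tensor of discrete timestep indices in
# ascending order (the loop walks it from the last, most-noised, entry down to index 1).
# `eta` scales the stochastic noise term; eta=0.0 gives deterministic DDIM.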
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones((x.shape[0]))
    s_x = x.new_ones((x.shape[0], 1, 1, 1))
    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
        index = len(timesteps) - 1 - i

        e_t = model(x, timesteps[index].item() * s_in, **extra_args)

        a_t = alphas[index].item() * s_x
        a_prev = alphas_prev[index].item() * s_x
        sigma_t = sigmas[index].item() * s_x
        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x

        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
        x = a_prev.sqrt() * pred_x0 + dir_xt + noise

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})

    return x
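
# PLMS (Pseudo Linear Multistep) sampling from the PNDM family of samplers: the first step
# uses a pseudo improved Euler step; subsequent steps reuse the previous epsilon predictions
# stored in old_eps for Adams-Bashforth-style multistep updates of increasing order
# (up to 4th order).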
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    s_x = x.new_ones((x.shape[0], 1, 1, 1))
    old_eps = []

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = alphas[index].item() * s_x
        a_prev = alphas_prev[index].item() * s_x
        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

        # direction pointing to x_t
        dir_xt = (1. - a_prev).sqrt() * e_t
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt
        return x_prev, pred_x0

    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
        index = len(timesteps) - 1 - i
        ts = timesteps[index].item() * s_in
        t_next = timesteps[max(index - 1, 0)].item() * s_in

        e_t = model(x, ts, **extra_args)

        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = model(x_prev, t_next, **extra_args)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        else:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        old_eps.append(e_t)
        if len(old_eps) >= 4:
            old_eps.pop(0)

        x = x_prev

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})

    return x
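
# UniPC predictor-corrector sampler specialized for a classifier-free-guidance denoiser:
# cfg_model is called in place of UniPC's usual noise-prediction model, and after_update
# forwards each intermediate result to the same progress-callback dict used by ddim/plms
# above. get_model_input_time maps UniPC's continuous time in (1/N, 1] onto the discrete
# 0..999 timestep scale the model expects.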
class UniPCCFG(uni_pc.UniPC):
    def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
        super().__init__(None, *args, **kwargs)

        def after_update(x, model_x):
            callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
            self.index += 1

        self.cfg_model = cfg_model
        self.extra_args = extra_args
        self.callback = callback
        self.index = 0
        self.after_update = after_update

    def get_model_input_time(self, t_continuous):
        return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.

    def model(self, x, t):
        t_input = self.get_model_input_time(t)
        res = self.cfg_model(x, t_input, **self.extra_args)
        return res
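
# Entry point for UniPC sampling: builds a discrete NoiseScheduleVP from the model's
# alphas_cumprod and runs the multistep UniPC sampler with the variant, skip type, order
# and lower_order_final settings taken from shared.opts. For img2img, t_start restricts
# sampling to the remaining (partially noised) portion of the schedule.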
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod

    ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None  # this is likely off by a bit - if someone wants to fix it please by all means
    unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
    x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)

    return x
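
# Rough usage sketch (an assumption about the calling convention, not something defined in
# this module): the caller is expected to provide a CFG denoiser wrapper whose
# inner_model.inner_model is the underlying LDM model, a noised latent x, an ascending 1-D
# tensor of discrete timesteps, and conditioning passed through extra_args, e.g.:
#
#     timesteps = torch.linspace(0, 999, num_steps, device=x.device).to(torch.long)
#     samples = ddim(denoiser, noised_latent, timesteps, extra_args=sampler_extra_args)
#
# The exact wrapper classes and extra_args keys are defined by the timesteps-based sampler
# module that imports these functions, not here.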