|
@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
|
|
|
return x
|
|
|
|
|
|
|
|
|
+@torch.no_grad()
|
|
|
+def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
|
|
+ """ Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
|
|
|
+ Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
|
|
|
+ The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
|
|
|
+ """
|
|
|
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
|
|
+ alphas = alphas_cumprod[timesteps]
|
|
|
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
|
|
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
|
|
+ sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
|
|
+
|
|
|
+ extra_args = {} if extra_args is None else extra_args
|
|
|
+ s_in = x.new_ones((x.shape[0]))
|
|
|
+ s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
|
|
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
|
|
+ index = len(timesteps) - 1 - i
|
|
|
+
|
|
|
+ e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
|
|
+ last_noise_uncond = model.last_noise_uncond
|
|
|
+
|
|
|
+ a_t = alphas[index].item() * s_x
|
|
|
+ a_prev = alphas_prev[index].item() * s_x
|
|
|
+ sigma_t = sigmas[index].item() * s_x
|
|
|
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
|
|
+
|
|
|
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
|
|
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
|
|
|
+ noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
|
|
|
+ x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
|
|
+
|
|
|
+ if callback is not None:
|
|
|
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
|
|
+
|
|
|
+ return x
|
|
|
+
|
|
|
+
|
|
|
@torch.no_grad()
|
|
|
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
|
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|