Bläddra i källkod

add new sampler DDIM CFG++

v0xie 1 år sedan
förälder
incheckning
663a4d80df

+ 10 - 0
modules/sd_samplers_cfg_denoiser.py

@@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
         self.model_wrap = None
         self.p = None
 
+        self.last_noise_uncond = None
+
         # NOTE: masking before denoising can cause the original latents to be oversmoothed
         # as the original latents do not have noise
         self.mask_before_denoising = False
@@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
         # so is_edit_model is set to False to support AND composition.
         is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
 
+        is_cfg_pp = 'CFG++' in self.sampler.config.name
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
 
@@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
 
+        if is_cfg_pp:
+            self.last_noise_uncond = x_out[-uncond.shape[0]:]
+            self.last_noise_uncond = torch.clone(self.last_noise_uncond)
+
         if is_edit_model:
             denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
         elif skip_uncond:
             denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+        elif is_cfg_pp:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale/12.5) # CFG++ scale of (0, 1) maps to (1.0, 12.5)
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
 

+ 1 - 0
modules/sd_samplers_timesteps.py

@@ -10,6 +10,7 @@ import modules.shared as shared
 
 samplers_timesteps = [
     ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
+    ('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
     ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
     ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
 ]

+ 37 - 0
modules/sd_samplers_timesteps_impl.py

@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
     return x
 
 
+@torch.no_grad()
+def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+    """ Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
+    Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
+    The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
+    """
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas = alphas_cumprod[timesteps]
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
+    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones((x.shape[0]))
+    s_x = x.new_ones((x.shape[0], 1, 1, 1))
+    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+        index = len(timesteps) - 1 - i
+
+        e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+        last_noise_uncond = model.last_noise_uncond
+
+        a_t = alphas[index].item() * s_x
+        a_prev = alphas_prev[index].item() * s_x
+        sigma_t = sigmas[index].item() * s_x
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
+
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
+        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+        x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+    return x
+
+
 @torch.no_grad()
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod