
Implements "scheduling" for blending of the original latents and a latent blending formula that preserves details in blend transition areas.

CodeHatchling · 1 year ago · commit e715e46b6a
1 changed file with 59 additions and 2 deletions

modules/sd_samplers_cfg_denoiser.py · +59 −2
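For context before the diff: a plain linear blend `a * (1 - t) + b * t` shrinks vector magnitudes wherever `a` and `b` point in different directions, which is what washes out detail in mask transition areas. The new `latent_blend` below counteracts this by interpolating magnitudes separately and rescaling the result. A minimal standalone sketch of the idea, with hypothetical names and toy shapes (the `detail_preservation` constant mirrors the diff):

```python
import torch

def latent_blend_sketch(a, b, t, detail_preservation=32):
    # L2 magnitude over the channel dim; keepdim=True keeps the result
    # broadcastable against the [B, C, H, W] latents. Raising the magnitudes
    # to a high power biases the interpolation toward the larger of the two.
    a_mag = torch.norm(a, p=2, dim=1, keepdim=True).double() ** detail_preservation
    b_mag = torch.norm(b, p=2, dim=1, keepdim=True).double() ** detail_preservation

    # Interpolate the powered magnitudes, then undo the power.
    target_mag = (a_mag * (1 - t) + b_mag * t) ** (1 / detail_preservation)

    # Plain linear interpolation, whose magnitude sags in transition areas...
    interp = a * (1 - t) + b * t
    interp_mag = torch.norm(interp, p=2, dim=1, keepdim=True).double() + 0.0001

    # ...so rescale it to the interpolated target magnitude.
    return interp * (target_mag / interp_mag).to(interp.dtype)

# Toy latents: batch 1, 4 channels, 8x8 (a typical SD latent layout).
a, b = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
t = torch.full((1, 1, 8, 8), 0.5)

blended = latent_blend_sketch(a, b, t)
naive = a * 0.5 + b * 0.5
# Per-pixel magnitudes of the rescaled blend stay near max(|a|, |b|), while
# the naive blend's shrink by roughly sqrt(2) for independent random vectors.
print(torch.norm(blended, dim=1).mean(), torch.norm(naive, dim=1).mean())
```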

@@ -43,6 +43,9 @@ class CFGDenoiser(torch.nn.Module):
         self.model_wrap = None
         self.mask = None
         self.nmask = None
+        self.mask_blend_power = 1
+        self.mask_blend_scale = 1
+        self.mask_blend_offset = 0
         self.init_latent = None
         self.steps = None
         """number of steps as specified by user in UI"""
@@ -56,6 +59,9 @@ class CFGDenoiser(torch.nn.Module):
         self.sampler = sampler
         self.model_wrap = None
         self.p = None
+
+        # NOTE: masking before denoising can cause the original latents to be oversmoothed
+        # as the original latents do not have noise
         self.mask_before_denoising = False
 
     @property
@@ -89,6 +95,55 @@ class CFGDenoiser(torch.nn.Module):
         self.sampler.sampler_extra_args['uncond'] = uc
 
     def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+        def latent_blend(a, b, t):
+            """
+            Interpolates two latent image representations according to the parameter t,
+            where the interpolated vectors' magnitudes are also interpolated separately.
+            The "detail_preservation" factor biases the magnitude interpolation towards
+            the larger of the two magnitudes.
+            """
+            # Record the original latent vector magnitudes.
+            # We bring them to a power so that larger magnitudes are favored over smaller ones.
+            # 64-bit operations are used here to allow large exponents.
+            detail_preservation = 32
+            a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64) ** detail_preservation
+            b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64) ** detail_preservation
+
+            one_minus_t = 1 - t
+
+            # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+            interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation)
+
+            # Linearly interpolate the image vectors.
+            image_interp = a * one_minus_t + b * t
+
+            # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+            # 64-bit operations are used here to allow large exponents.
+            image_interp_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64) + 0.0001
+
+            # Change the linearly interpolated image vectors' magnitudes to the value we want.
+            # This is the last 64-bit operation.
+            image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype)
+
+            return image_interp
+
+        def get_modified_nmask(nmask, _sigma):
+            """
+            Converts a negative mask representing the transparency of the original latent vectors being overlaid
+            to a mask that is scaled according to the denoising strength for this step.
+
+            Where:
+                0 = fully opaque, infinite density, fully masked
+                1 = fully transparent, zero density, fully unmasked
+
+            We raise this transparency to a power, which allows one to simulate N blending operations,
+            where N can be any positive real value. Using this, one can control the balance of influence
+            between the denoiser and the original latents according to the sigma value.
+
+            NOTE: self.mask is not used here; only the negative mask (nmask) is needed.
+            """
+            return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset)
+
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException
 
@@ -105,8 +160,9 @@ class CFGDenoiser(torch.nn.Module):
 
         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
 
+        # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
-            x = self.init_latent * self.mask + self.nmask * x
+            x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))
 
         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +263,9 @@ class CFGDenoiser(torch.nn.Module):
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
 
+        # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
+            denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))
 
         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
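How the three new knobs shape the schedule: `get_modified_nmask` raises the soft mask to a sigma-dependent power, so the same transition pixel is blended differently at different points in sampling. An illustrative, standalone evaluation of the formula (not part of the commit; the knob values are the defaults set in `__init__` above, and the sigma values are rough samples from a typical SD noise schedule, high to low):

```python
# Defaults mirroring the new fields added in __init__.
mask_blend_power = 1
mask_blend_scale = 1
mask_blend_offset = 0

def modified_nmask(nmask, sigma):
    # Same formula as get_modified_nmask in the diff, on plain floats.
    return nmask ** (sigma ** mask_blend_power * mask_blend_scale + mask_blend_offset)

nmask = 0.5  # a half-transparent pixel in the mask's transition region
for sigma in (14.6, 5.0, 1.0, 0.1):
    print(f"sigma={sigma:5.1f} -> blend weight t={modified_nmask(nmask, sigma):.4f}")
# sigma=14.6 -> t=0.0000  (the original latents dominate the early, noisy steps)
# sigma= 0.1 -> t=0.9330  (the denoiser's output dominates the final steps)
```

Since `t` is the weight of the denoised latents in `latent_blend`, the exponent grows with sigma: soft-mask pixels collapse toward 0 early on and recover toward their raw transparency as sigma falls. `mask_blend_power` bends that trajectory, `mask_blend_scale` stretches it, and `mask_blend_offset` sets the exponent's floor at sigma = 0.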