Переглянути джерело

Removed soft inpainting, added hooks for softpainting to work instead.

CodeHatchling 1 рік тому
батько
коміт
ac45789123
3 змінених файлів з 118 додано та 69 видалено
  1. 38 56
      modules/processing.py
  2. 70 0
      modules/scripts.py
  3. 10 13
      modules/sd_samplers_cfg_denoiser.py

+ 38 - 56
modules/processing.py

@@ -30,7 +30,6 @@ import modules.sd_models as sd_models
 import modules.sd_vae as sd_vae
 import modules.sd_vae as sd_vae
 from ldm.data.util import AddMiDaS
 from ldm.data.util import AddMiDaS
 from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
-import modules.soft_inpainting as si
 
 
 from einops import repeat, rearrange
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
 from blendmodes.blend import blendLayers, BlendType
@@ -73,12 +72,10 @@ def uncrop(image, dest_size, paste_loc):
     return image
     return image
 
 
 
 
-def apply_overlay(image, paste_loc, index, overlays):
-    if overlays is None or index >= len(overlays):
+def apply_overlay(image, paste_loc, overlay):
+    if overlay is None:
         return image
         return image
 
 
-    overlay = overlays[index]
-
     if paste_loc is not None:
     if paste_loc is not None:
         image = uncrop(image, (overlay.width, overlay.height), paste_loc)
         image = uncrop(image, (overlay.width, overlay.height), paste_loc)
 
 
@@ -150,7 +147,6 @@ class StableDiffusionProcessing:
     do_not_save_grid: bool = False
     do_not_save_grid: bool = False
     extra_generation_params: dict[str, Any] = None
     extra_generation_params: dict[str, Any] = None
     overlay_images: list = None
     overlay_images: list = None
-    masks_for_overlay: list = None
     eta: float = None
     eta: float = None
     do_not_reload_embeddings: bool = False
     do_not_reload_embeddings: bool = False
     denoising_strength: float = None
     denoising_strength: float = None
@@ -880,31 +876,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
 
+            if p.scripts is not None:
+                ps = scripts.PostSampleArgs(samples_ddim)
+                p.scripts.post_sample(p, ps)
+                samples_ddim = pp.samples
+
             if getattr(samples_ddim, 'already_decoded', False):
             if getattr(samples_ddim, 'already_decoded', False):
                 x_samples_ddim = samples_ddim
                 x_samples_ddim = samples_ddim
-                # todo: generate adaptive masks based on pixel differences.
-                if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
-                    si.apply_masks(soft_inpainting=p.soft_inpainting,
-                                   nmask=p.nmask,
-                                   overlay_images=p.overlay_images,
-                                   masks_for_overlay=p.masks_for_overlay,
-                                   width=p.width,
-                                   height=p.height,
-                                   paste_to=p.paste_to)
             else:
             else:
                 if opts.sd_vae_decode_method != 'Full':
                 if opts.sd_vae_decode_method != 'Full':
                     p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                     p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
 
 
-                # Generate the mask(s) based on similarity between the original and denoised latent vectors
-                if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
-                    si.apply_adaptive_masks(latent_orig=p.init_latent,
-                                            latent_processed=samples_ddim,
-                                            overlay_images=p.overlay_images,
-                                            masks_for_overlay=p.masks_for_overlay,
-                                            width=p.width,
-                                            height=p.height,
-                                            paste_to=p.paste_to)
-
                 x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
                 x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
 
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -955,9 +937,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     pp = scripts.PostprocessImageArgs(image)
                     pp = scripts.PostprocessImageArgs(image)
                     p.scripts.postprocess_image(p, pp)
                     p.scripts.postprocess_image(p, pp)
                     image = pp.image
                     image = pp.image
+
+                mask_for_overlay = p.mask_for_overlay
+                overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None
+
+                if p.scripts is not None:
+                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+                    p.scripts.postprocess_maskoverlay(p, ppmo)
+                    mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image
+
                 if p.color_corrections is not None and i < len(p.color_corrections):
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if save_samples and opts.save_images_before_color_correction:
                     if save_samples and opts.save_images_before_color_correction:
-                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                        image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
                         images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                         images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                     image = apply_color_correction(p.color_corrections[i], image)
                     image = apply_color_correction(p.color_corrections[i], image)
 
 
@@ -968,9 +959,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 original_denoised_image = image.copy()
                 original_denoised_image = image.copy()
 
 
                 if p.paste_to is not None:
                 if p.paste_to is not None:
-                    original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to)
+                    original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to)
 
 
-                image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                image = apply_overlay(image, p.paste_to, overlay_image)
 
 
                 if save_samples:
                 if save_samples:
                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -981,13 +972,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     image.info["parameters"] = text
                     image.info["parameters"] = text
                 output_images.append(image)
                 output_images.append(image)
 
 
-                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
-                    mask_for_overlay = p.mask_for_overlay
-                elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]:
-                    mask_for_overlay = p.masks_for_overlay[i]
-                else:
-                    mask_for_overlay = None
-
                 if mask_for_overlay is not None:
                 if mask_for_overlay is not None:
                     if opts.return_mask or opts.save_mask:
                     if opts.return_mask or opts.save_mask:
                         image_mask = mask_for_overlay.convert('RGB')
                         image_mask = mask_for_overlay.convert('RGB')
@@ -1401,7 +1385,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     mask_blur_x: int = 4
     mask_blur_x: int = 4
     mask_blur_y: int = 4
     mask_blur_y: int = 4
     mask_blur: int = None
     mask_blur: int = None
-    soft_inpainting: si.SoftInpaintingParameters = si.default
+    mask_round: bool = True
     inpainting_fill: int = 0
     inpainting_fill: int = 0
     inpaint_full_res: bool = True
     inpaint_full_res: bool = True
     inpaint_full_res_padding: int = 0
     inpaint_full_res_padding: int = 0
@@ -1447,7 +1431,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         if image_mask is not None:
         if image_mask is not None:
             # image_mask is passed in as RGBA by Gradio to support alpha masks,
             # image_mask is passed in as RGBA by Gradio to support alpha masks,
             # but we still want to support binary masks.
             # but we still want to support binary masks.
-            image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None))
+            image_mask = create_binary_mask(image_mask, round=self.mask_round)
 
 
             if self.inpainting_mask_invert:
             if self.inpainting_mask_invert:
                 image_mask = ImageOps.invert(image_mask)
                 image_mask = ImageOps.invert(image_mask)
@@ -1465,7 +1449,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                 image_mask = Image.fromarray(np_mask)
                 image_mask = Image.fromarray(np_mask)
 
 
             if self.inpaint_full_res:
             if self.inpaint_full_res:
-                self.mask_for_overlay = image_mask if self.soft_inpainting is None else None
+                self.mask_for_overlay = image_mask
                 mask = image_mask.convert('L')
                 mask = image_mask.convert('L')
                 crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                 crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                 crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                 crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1476,13 +1460,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                 self.paste_to = (x1, y1, x2-x1, y2-y1)
                 self.paste_to = (x1, y1, x2-x1, y2-y1)
             else:
             else:
                 image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                 image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+                np_mask = np.array(image_mask)
+                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+                self.mask_for_overlay = Image.fromarray(np_mask)
 
 
-                if self.soft_inpainting is None:
-                    np_mask = np.array(image_mask)
-                    np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
-                    self.mask_for_overlay = Image.fromarray(np_mask)
-
-            self.masks_for_overlay = [] if self.soft_inpainting is not None else None
             self.overlay_images = []
             self.overlay_images = []
 
 
         latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
         latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1504,15 +1485,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                 image = images.resize_image(self.resize_mode, image, self.width, self.height)
                 image = images.resize_image(self.resize_mode, image, self.width, self.height)
 
 
             if image_mask is not None:
             if image_mask is not None:
-                if self.soft_inpainting is not None:
-                    # We apply the masks AFTER to adjust mask based on changed content.
-                    self.overlay_images.append(image.convert('RGBA'))
-                    self.masks_for_overlay.append(image_mask)
-                else:
-                    image_masked = Image.new('RGBa', (image.width, image.height))
-                    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+                image_masked = Image.new('RGBa', (image.width, image.height))
+                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
 
 
-                    self.overlay_images.append(image_masked.convert('RGBA'))
+                self.overlay_images.append(image_masked.convert('RGBA'))
 
 
             # crop_region is not None if we are doing inpaint full res
             # crop_region is not None if we are doing inpaint full res
             if crop_region is not None:
             if crop_region is not None:
@@ -1565,7 +1541,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = latmask[0]
             latmask = latmask[0]
-            if self.soft_inpainting is None:
+            if self.mask_round:
                 latmask = np.around(latmask)
                 latmask = np.around(latmask)
             latmask = np.tile(latmask[None], (4, 1, 1))
             latmask = np.tile(latmask[None], (4, 1, 1))
 
 
@@ -1578,7 +1554,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             elif self.inpainting_fill == 3:
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask
                 self.init_latent = self.init_latent * self.mask
 
 
-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None)
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
 
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         x = self.rng.next()
         x = self.rng.next()
@@ -1589,8 +1565,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
 
 
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
 
 
-        if self.mask is not None and self.soft_inpainting is None:
-            samples = samples * self.nmask + self.init_latent * self.mask
+        blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+        if self.scripts is not None:
+            mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True)
+            self.scripts.on_mask_blend(self, mba)
+            blended_samples = mba.blended_latent
+
+        samples = blended_samples
 
 
         del x
         del x
         devices.torch_gc()
         devices.torch_gc()

+ 70 - 0
modules/scripts.py

@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
 
 
 AlwaysVisible = object()
 AlwaysVisible = object()
 
 
+class MaskBlendArgs:
+    def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None):
+        self.current_latent = current_latent
+        self.nmask = nmask
+        self.init_latent = init_latent
+        self.mask = mask
+        self.blended_samples = blended_samples
+
+        self.denoiser = denoiser
+        self.is_final_blend = denoiser is None
+        self.sigma = sigma
+
+class PostSampleArgs:
+    def __init__(self, samples):
+        self.samples = samples
 
 
 class PostprocessImageArgs:
 class PostprocessImageArgs:
     def __init__(self, image):
     def __init__(self, image):
         self.image = image
         self.image = image
 
 
+class PostProcessMaskOverlayArgs:
+    def __init__(self, index, mask_for_overlay, overlay_image):
+        self.index = index
+        self.mask_for_overlay = mask_for_overlay
+        self.overlay_image = overlay_image
 
 
 class PostprocessBatchListArgs:
 class PostprocessBatchListArgs:
     def __init__(self, images):
     def __init__(self, images):
@@ -206,6 +226,25 @@ class Script:
 
 
         pass
         pass
 
 
+    def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+        """
+        Called in inpainting mode when the original content is blended with the inpainted content.
+        This is called at every step in the denoising process and once at the end.
+        If is_final_blend is true, this is called for the final blending stage.
+        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+        """
+
+        pass
+
+    def post_sample(self, p, ps: PostSampleArgs, *args):
+        """
+        Called after the samples have been generated,
+        but before they have been decoded by the VAE, if applicable.
+        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+        """
+
+        pass
+
     def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
     def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
         """
         """
         Called for every image after it has been generated.
         Called for every image after it has been generated.
@@ -213,6 +252,13 @@ class Script:
 
 
         pass
         pass
 
 
+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+        """
+        Called for every image after it has been generated.
+        """
+
+        pass
+
     def postprocess(self, p, processed, *args):
     def postprocess(self, p, processed, *args):
         """
         """
         This function is called after processing ends for AlwaysVisible scripts.
         This function is called after processing ends for AlwaysVisible scripts.
@@ -767,6 +813,22 @@ class ScriptRunner:
             except Exception:
             except Exception:
                 errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
                 errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
 
 
+    def post_sample(self, p, ps: PostSampleArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.post_sample(p, ps, *script_args)
+            except Exception:
+                errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+    def on_mask_blend(self, p, mba: MaskBlendArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.on_mask_blend(p, mba, *script_args)
+            except Exception:
+                errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
     def postprocess_image(self, p, pp: PostprocessImageArgs):
     def postprocess_image(self, p, pp: PostprocessImageArgs):
         for script in self.alwayson_scripts:
         for script in self.alwayson_scripts:
             try:
             try:
@@ -775,6 +837,14 @@ class ScriptRunner:
             except Exception:
             except Exception:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
 
 
+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_maskoverlay(p, ppmo, *script_args)
+            except Exception:
+                errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+
     def before_component(self, component, **kwargs):
     def before_component(self, component, **kwargs):
         for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
         for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
             try:
             try:

+ 10 - 13
modules/sd_samplers_cfg_denoiser.py

@@ -109,19 +109,16 @@ class CFGDenoiser(torch.nn.Module):
         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
 
 
         # If we use masks, blending between the denoised and original latent images occurs here.
         # If we use masks, blending between the denoised and original latent images occurs here.
-        def apply_blend(latent):
-            if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function):
-                return self.p.denoiser_masked_blend_function(
-                    self,
-                    # Using an argument dictionary so that arguments can be added without breaking extensions.
-                    args=
-                    {
-                        "denoiser": self,
-                        "current_latent": latent,
-                        "sigma": sigma
-                    })
-            else:
-                return self.init_latent * self.mask + self.nmask * latent
+        def apply_blend(current_latent):
+            blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+            if self.p.scripts is not None:
+                from modules import scripts
+                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+                self.p.scripts.on_mask_blend(self.p, mba)
+                blended_latent = mba.blended_latent
+
+            return blended_latent
 
 
         # Blend in the original latents (before)
         # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
         if self.mask_before_denoising and self.mask is not None: