Explorar el Código

Add process_before_every_sampling hook

huchenlei hace 1 año
padre
commit
17e846150c
Se han modificado 2 ficheros con 39 adiciones y 0 borrados
  1. 24 0
      modules/processing.py
  2. 15 0
      modules/scripts.py

+ 24 - 0
modules/processing.py

@@ -1330,6 +1330,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             # here we generate an image normally
 
             x = self.rng.next()
+            if self.scripts is not None:
+                self.scripts.process_before_every_sampling(
+                    p=self,
+                    x=x,
+                    noise=x,
+                    c=conditioning,
+                    uc=unconditional_conditioning
+                )
+
             samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
             del x
 
@@ -1430,6 +1439,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         if self.scripts is not None:
             self.scripts.before_hr(self)
+            self.scripts.process_before_every_sampling(
+                p=self,
+                x=samples,
+                noise=noise,
+                c=self.hr_c,
+                uc=self.hr_uc,
+            )
 
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
@@ -1743,6 +1759,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
             x *= self.initial_noise_multiplier
 
+        if self.scripts is not None:
+            self.scripts.process_before_every_sampling(
+                p=self,
+                x=self.init_latent,
+                noise=x,
+                c=conditioning,
+                uc=unconditional_conditioning
+            )
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
 
         if self.mask is not None:

+ 15 - 0
modules/scripts.py

@@ -187,6 +187,13 @@ class Script:
         """
         pass
 
+    def process_before_every_sampling(self, p, *args, **kwargs):
+        """
+        Similar to process(), called before every sampling.
+        If you use high-res fix, this will be called two times.
+        """
+        pass
+
     def process_batch(self, p, *args, **kwargs):
         """
         Same as process(), but called for every batch.
@@ -826,6 +833,14 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running process: {script.filename}", exc_info=True)
 
+    def process_before_every_sampling(self, p, **kwargs):
+        for script in self.ordered_scripts('process_before_every_sampling'):
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.process_before_every_sampling(p, *script_args, **kwargs)
+            except Exception:
+                errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True)
+
     def before_process_batch(self, p, **kwargs):
         for script in self.ordered_scripts('before_process_batch'):
             try: