فهرست منبع

Add process_before_every_sampling hook

huchenlei 1 سال پیش
والد
کامیت
17e846150c
2فایلهای تغییر یافته به همراه39 افزوده شده و 0 حذف شده
  1. 24 0
      modules/processing.py
  2. 15 0
      modules/scripts.py

+ 24 - 0
modules/processing.py

@@ -1330,6 +1330,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             # here we generate an image normally
             # here we generate an image normally
 
 
             x = self.rng.next()
             x = self.rng.next()
+            if self.scripts is not None:
+                self.scripts.process_before_every_sampling(
+                    p=self,
+                    x=x,
+                    noise=x,
+                    c=conditioning,
+                    uc=unconditional_conditioning
+                )
+
             samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
             samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
             del x
             del x
 
 
@@ -1430,6 +1439,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
 
         if self.scripts is not None:
         if self.scripts is not None:
             self.scripts.before_hr(self)
             self.scripts.before_hr(self)
+            self.scripts.process_before_every_sampling(
+                p=self,
+                x=samples,
+                noise=noise,
+                c=self.hr_c,
+                uc=self.hr_uc,
+            )
 
 
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
 
@@ -1743,6 +1759,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
             self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
             x *= self.initial_noise_multiplier
             x *= self.initial_noise_multiplier
 
 
+        if self.scripts is not None:
+            self.scripts.process_before_every_sampling(
+                p=self,
+                x=self.init_latent,
+                noise=x,
+                c=conditioning,
+                uc=unconditional_conditioning
+            )
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
 
 
         if self.mask is not None:
         if self.mask is not None:

+ 15 - 0
modules/scripts.py

@@ -187,6 +187,13 @@ class Script:
         """
         """
         pass
         pass
 
 
+    def process_before_every_sampling(self, p, *args, **kwargs):
+        """
+        Similar to process(), called before every sampling.
+        If you use high-res fix, this will be called two times.
+        """
+        pass
+
     def process_batch(self, p, *args, **kwargs):
     def process_batch(self, p, *args, **kwargs):
         """
         """
         Same as process(), but called for every batch.
         Same as process(), but called for every batch.
@@ -826,6 +833,14 @@ class ScriptRunner:
             except Exception:
             except Exception:
                 errors.report(f"Error running process: {script.filename}", exc_info=True)
                 errors.report(f"Error running process: {script.filename}", exc_info=True)
 
 
+    def process_before_every_sampling(self, p, **kwargs):
+        for script in self.ordered_scripts('process_before_every_sampling'):
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.process_before_every_sampling(p, *script_args, **kwargs)
+            except Exception:
+                errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True)
+
     def before_process_batch(self, p, **kwargs):
     def before_process_batch(self, p, **kwargs):
         for script in self.ordered_scripts('before_process_batch'):
         for script in self.ordered_scripts('before_process_batch'):
             try:
             try: