浏览代码

Add before_process_batch script callback

space-nuko 2 年之前
父节点
当前提交
a2d635ad13
共有 2 个文件被更改,包括 26 次插入0 次删除
  1. 3 0
      modules/processing.py
  2. 23 0
      modules/scripts.py

+ 3 - 0
modules/processing.py

@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
             subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
 
+            if p.scripts is not None:
+                p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
             if len(prompts) == 0:
             if len(prompts) == 0:
                 break
                 break
 
 

+ 23 - 0
modules/scripts.py

@@ -80,6 +80,20 @@ class Script:
 
 
         pass
         pass
 
 
+    def before_process_batch(self, p, *args, **kwargs):
+        """
+        Called before extra networks are parsed from the prompt, so you can add
+        new extra network keywords to the prompt with this callback.
+
+        **kwargs will have those items:
+          - batch_number - index of current batch, from 0 to number of batches-1
+          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+          - seeds - list of seeds for current batch
+          - subseeds - list of subseeds for current batch
+        """
+
+        pass
+
     def process_batch(self, p, *args, **kwargs):
     def process_batch(self, p, *args, **kwargs):
         """
         """
         Same as process(), but called for every batch.
         Same as process(), but called for every batch.
@@ -388,6 +402,15 @@ class ScriptRunner:
                 print(f"Error running process: {script.filename}", file=sys.stderr)
                 print(f"Error running process: {script.filename}", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
 
 
+    def before_process_batch(self, p, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.before_process_batch(p, *script_args, **kwargs)
+            except Exception:
+                print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
     def process_batch(self, p, **kwargs):
     def process_batch(self, p, **kwargs):
         for script in self.alwayson_scripts:
         for script in self.alwayson_scripts:
             try:
             try: