浏览代码

set loopback color corrections on each iteration

Brian Drupieski 2 年之前
父节点
当前提交
7bc5739fe2
共有 3 个文件被更改,包括 5 次插入6 次删除
  1. 2 0
      modules/processing.py
  2. 1 4
      scripts/batch.py
  3. 2 2
      scripts/loopback.py

+ 2 - 0
modules/processing.py

@@ -339,6 +339,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
             state.nextjob()
 
+        p.color_corrections = None
+
         unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
         if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
             grid = images.image_grid(output_images, p.batch_size)

+ 1 - 4
scripts/batch.py

@@ -6,7 +6,7 @@ import traceback
 import modules.scripts as scripts
 import gradio as gr
 
-from modules.processing import Processed, process_images, setup_color_correction
+from modules.processing import Processed, process_images
 from PIL import Image
 from modules.shared import opts, cmd_opts, state
 
@@ -52,9 +52,6 @@ class Script(scripts.Script):
             state.job = f"{batch_no} out of {batch_count}: {batch_images[0][1]}"
             p.init_images = [x[0] for x in batch_images]
 
-            if opts.img2img_color_correction:
-                p.color_corrections = [setup_color_correction(i) for i in p.init_images]
-
             proc = process_images(p)
             for image, (_, path) in zip(proc.images, batch_images):
                 filename = os.path.basename(path)

+ 2 - 2
scripts/loopback.py

@@ -40,8 +40,7 @@ class Script(scripts.Script):
         all_images = []
         state.job_count = loops * batch_count
 
-        if opts.img2img_color_correction:
-            p.color_corrections = [processing.setup_color_correction(p.init_images[0])]
+        initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
 
         for n in range(batch_count):
             history = []
@@ -50,6 +49,7 @@ class Script(scripts.Script):
                 p.n_iter = 1
                 p.batch_size = 1
                 p.do_not_save_grid = True
+                p.color_corrections = initial_color_corrections
 
                 state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}"