Преглед изворни кода

make it possible for StableDiffusionProcessing to accept multiple different negative prompts in a batch

AUTOMATIC пре 2 година
родитељ
комит
617c5b486f
2 измењених фајлова са 24 додато и 33 уклоњено
  1. 24 22
      modules/processing.py
  2. 0 11
      modules/styles.py

+ 24 - 22
modules/processing.py

@@ -124,6 +124,7 @@ class StableDiffusionProcessing():
         self.scripts = None
         self.script_args = None
         self.all_prompts = None
+        self.all_negative_prompts = None
         self.all_seeds = None
         self.all_subseeds = None
 
@@ -202,7 +203,7 @@ class StableDiffusionProcessing():
 
 
 class Processed:
-    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
         self.images = images_list
         self.prompt = p.prompt
         self.negative_prompt = p.negative_prompt
@@ -241,16 +242,18 @@ class Processed:
         self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
         self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
 
-        self.all_prompts = all_prompts or [self.prompt]
-        self.all_seeds = all_seeds or [self.seed]
-        self.all_subseeds = all_subseeds or [self.subseed]
+        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
+        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
+        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
+        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
         self.infotexts = infotexts or [info]
 
     def js(self):
         obj = {
-            "prompt": self.prompt,
+            "prompt": self.all_prompts[0],
             "all_prompts": self.all_prompts,
-            "negative_prompt": self.negative_prompt,
+            "negative_prompt": self.all_negative_prompts[0],
+            "all_negative_prompts": self.all_negative_prompts,
             "seed": self.seed,
             "all_seeds": self.all_seeds,
             "subseed": self.subseed,
@@ -411,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
 
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
-    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
+    negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if  p.all_negative_prompts[0] else ""
 
     return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
 
@@ -440,10 +443,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     else:
         assert p.prompt is not None
 
-    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
-        processed = Processed(p, [], p.seed, "")
-        file.write(processed.infotext(p, 0))
-
     devices.torch_gc()
 
     seed = get_fixed_seed(p.seed)
@@ -453,15 +452,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     modules.sd_hijack.model_hijack.clear_comments()
 
     comments = {}
-    prompt_tmp = p.prompt
-    negative_prompt_tmp = p.negative_prompt
-
-    shared.prompt_styles.apply_styles(p)
 
     if type(p.prompt) == list:
-        p.all_prompts = p.prompt
+        p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
+    else:
+        p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
+
+    if type(p.negative_prompt) == list:
+        p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
     else:
-        p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
+        p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
 
     if type(seed) == list:
         p.all_seeds = seed
@@ -476,6 +476,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     def infotext(iteration=0, position_in_batch=0):
         return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
 
+    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+        processed = Processed(p, [], p.seed, "")
+        file.write(processed.infotext(p, 0))
+
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
 
@@ -500,6 +504,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 break
 
             prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+            negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
@@ -510,7 +515,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
 
             with devices.autocast():
-                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
                 c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
 
             if len(model_hijack.comments) > 0:
@@ -596,14 +601,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
     devices.torch_gc()
 
-    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
 
     if p.scripts is not None:
         p.scripts.postprocess(p, res)
 
-    p.prompt = prompt_tmp
-    p.negative_prompt = negative_prompt_tmp
-
     return res
 
 

+ 0 - 11
modules/styles.py

@@ -65,17 +65,6 @@ class StyleDatabase:
     def apply_negative_styles_to_prompt(self, prompt, styles):
         return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
 
-    def apply_styles(self, p: StableDiffusionProcessing) -> None:
-        if isinstance(p.prompt, list):
-            p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
-        else:
-            p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
-
-        if isinstance(p.negative_prompt, list):
-            p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
-        else:
-            p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
-
     def save_styles(self, path: str) -> None:
         # Write to temporary file first, so we don't nuke the file if something goes wrong
         fd, temp_path = tempfile.mkstemp(".csv")