
Merge remote-tracking branch 'InvincibleDude/improved-hr-conflict-test' into hires-fix-ext

AUTOMATIC · 2 years ago
commit 5ec2c294ee
4 changed files with 97 additions and 8 deletions:
  1. modules/generation_parameters_copypaste.py (+5 -1)
  2. modules/processing.py (+68 -5)
  3. modules/txt2img.py (+6 -2)
  4. modules/ui.py (+18 -0)

modules/generation_parameters_copypaste.py (+5 -1)

@@ -252,7 +252,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         if line.startswith("Negative prompt:"):
             done_with_prompt = True
             line = line[16:].strip()
-
         if done_with_prompt:
             negative_prompt += ("" if negative_prompt == "" else "\n") + line
         else:
@@ -270,6 +269,11 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         else:
             res[k] = v
 
+        if k.startswith("Hires prompt"):
+            res["Hires prompt"] = v[1:][:-1].replace(';', ',')
+        elif k.startswith("Hires negative prompt"):
+            res["Hires negative prompt"] = v[1:][:-1].replace(';', ',')
+
     # Missing CLIP skip means it was set to 1 (the default)
     if "Clip skip" not in res:
         res["Clip skip"] = "1"

modules/processing.py (+68 -5)

@@ -271,7 +271,7 @@ class StableDiffusionProcessing:
     def init(self, all_prompts, all_seeds, all_subseeds):
         pass
 
-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
         raise NotImplementedError()
 
     def close(self):
@@ -592,6 +592,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     else:
         p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
 
+    if type(p) == StableDiffusionProcessingTxt2Img:
+        if p.enable_hr and p.hr_prompt == '':
+            p.all_hr_prompts, p.all_hr_negative_prompts = p.all_prompts, p.all_negative_prompts
+        elif p.enable_hr and p.hr_prompt != '':
+            if type(p.prompt) == list:
+                p.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt]
+            else:
+                p.all_hr_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)]
+
+            if type(p.negative_prompt) == list:
+                p.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.hr_negative_prompt]
+            else:
+                p.all_hr_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)]
+
     if type(seed) == list:
         p.all_seeds = seed
     else:
@@ -660,6 +674,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
             prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+            if type(p) == StableDiffusionProcessingTxt2Img:
+                if p.enable_hr:
+                    if p.hr_prompt == '':
+                        hr_prompts, hr_negative_prompts = prompts, negative_prompts
+                    else:
+                        hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+                        hr_negative_prompts = p.all_hr_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
             seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
@@ -671,6 +694,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
             prompts, extra_network_data = extra_networks.parse_prompts(prompts)
 
+            if type(p) == StableDiffusionProcessingTxt2Img:
+                if p.enable_hr and hr_prompts != prompts:
+                    _, hr_extra_network_data = extra_networks.parse_prompts(hr_prompts)
+                    extra_network_data.update(hr_extra_network_data)
+
+
             if not p.disable_extra_networks:
                 with devices.autocast():
                     extra_networks.activate(p, extra_network_data)
@@ -692,6 +721,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
             c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
 
+            if type(p) == StableDiffusionProcessingTxt2Img:
+                if p.enable_hr:
+                    if prompts != hr_prompts:
+                        hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, p.steps, cached_uc)
+                        hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, p.steps, cached_c)
+                    else:
+                        hr_uc, hr_c = uc, c
+
             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:
                     comments[comment] = 1
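
When the hires prompts differ from the base prompts, the second pass gets its own conditioning; otherwise the tensors are shared outright and nothing is recomputed. For reference, a paraphrase of this module's single-slot conditioning cache (an assumption about its exact shape, not a verbatim copy):

    def get_conds_with_caching(function, required_prompts, steps, cache):
        # cache is a two-element list: cache[0] holds the (prompts, steps)
        # key of the most recent computation, cache[1] holds its result.
        if cache[0] is not None and (required_prompts, steps) == cache[0]:
            return cache[1]
        cache[1] = function(shared.sd_model, required_prompts, steps)
        cache[0] = (required_prompts, steps)
        return cache[1]

Because the cache holds only the most recent entry, interleaving hires lookups with base lookups can evict the base conditioning between iterations. Also note the hires calls above pass p.steps where the base calls pass p.steps * step_multiplier.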
@@ -699,8 +736,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
+
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
-                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+                if type(p) == StableDiffusionProcessingTxt2Img:
+                    if p.enable_hr:
+                        samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_unconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+                    else:
+                        samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+                else:
+                    samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
             x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
             for x in x_samples_ddim:
@@ -835,7 +879,7 @@ def old_hires_fix_first_pass_dimensions(width, height):
 class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     sampler = None
 
-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler: str = '---', hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
@@ -846,6 +890,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_resize_y = hr_resize_y
         self.hr_upscale_to_x = hr_resize_x
         self.hr_upscale_to_y = hr_resize_y
+        self.hr_sampler = hr_sampler
+        self.hr_prompt = hr_prompt if hr_prompt != '' else ''
+        self.hr_negative_prompt = hr_negative_prompt if hr_negative_prompt != '' else ''
+        self.all_hr_prompts = None
+        self.all_hr_negative_prompts = None
 
         if firstphase_width != 0 or firstphase_height != 0:
             self.hr_upscale_to_x = self.width
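
The hr_prompt if hr_prompt != '' else '' guards in the constructor above are no-ops: both branches yield the same value for every input, so they reduce to plain assignments (equivalent form, shown for clarity):

    self.hr_prompt = hr_prompt
    self.hr_negative_prompt = hr_negative_prompt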
@@ -859,6 +908,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
+            if self.hr_sampler != '---':
+                self.extra_generation_params["Hires sampler"] = self.hr_sampler
+
+            if self.hr_prompt != '':
+                self.extra_generation_params["Hires prompt"] = f'({self.hr_prompt.replace(",", ";")})'
+                self.extra_generation_params["Hires negative prompt"] = f'({self.hr_negative_prompt.replace(",", ";")})'
+
             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                 self.hr_resize_x = self.width
                 self.hr_resize_y = self.height
@@ -919,7 +975,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if self.hr_upscaler is not None:
                 self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
 
-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
         latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
@@ -989,8 +1045,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         shared.state.nextjob()
 
         img2img_sampler_name = self.sampler_name
+
         if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
             img2img_sampler_name = 'DDIM'
+
+        if self.hr_sampler == '---':
+            pass
+        else:
+            img2img_sampler_name = self.hr_sampler
+
         self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
 
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
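
Second-pass sampler selection, condensed: '---' is the sentinel for "reuse the txt2img sampler", and an explicit hires choice takes precedence over the PLMS/UniPC fallback. An equivalent, more direct form of the logic in the hunk above:

    img2img_sampler_name = self.sampler_name
    if self.sampler_name in ('PLMS', 'UniPC'):
        img2img_sampler_name = 'DDIM'            # neither supports img2img
    if self.hr_sampler != '---':
        img2img_sampler_name = self.hr_sampler   # explicit hires sampler wins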
@@ -1003,7 +1066,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
 
-        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+        samples = self.sampler.sample_img2img(self, samples, noise, hr_conditioning, hr_unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
 

modules/txt2img.py (+6 -2)

@@ -6,9 +6,10 @@ import modules.shared as shared
 from modules.ui import plaintext_to_html
 
 
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
-    override_settings = create_override_settings_dict(override_settings_texts)
 
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+    override_settings = create_override_settings_dict(override_settings_texts)
+    
     p = processing.StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -38,6 +39,9 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         hr_second_pass_steps=hr_second_pass_steps,
         hr_resize_x=hr_resize_x,
         hr_resize_y=hr_resize_y,
+        hr_sampler=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else '---',
+        hr_prompt=hr_prompt,
+        hr_negative_prompt=hr_negative_prompt,
         override_settings=override_settings,
     )
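
The dropdown reports an index, and index 0 is the "---" placeholder, so real samplers are offset by one. The mapping used in the hunk above as a standalone sketch (hypothetical helper; samplers stands in for sd_samplers.samplers_for_img2img):

    def hr_sampler_name(hr_sampler_index: int, samplers) -> str:
        # Index 0 is the '---' sentinel ("keep the first-pass sampler");
        # dropdown entries 1..N map to samplers[0..N-1].
        if hr_sampler_index == 0:
            return '---'
        return samplers[hr_sampler_index - 1].name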
 

modules/ui.py (+18 -0)

@@ -499,6 +499,17 @@ def create_ui():
                                 hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
                                 hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
 
+                            with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact"):
+                                hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id=f"hr_sampler", choices=["---"] + [x.name for x in samplers_for_img2img], value="---", type="index")
+
+                            with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact"):
+                                with gr.Column(scale=80):
+                                    with gr.Row():
+                                        hr_prompt = gr.Textbox(label="Prompt", elem_id=f"hires_prompt", show_label=False, lines=3, placeholder="Prompt that will be used for hires fix pass (leave it blank to use the same prompt as in initial txt2img gen)")
+                                with gr.Column(scale=80):
+                                    with gr.Row():
+                                        hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt that will be used for hires fix pass (leave it blank to use the same prompt as in initial txt2img gen)")
+
                     elif category == "batch":
                         if not opts.dimensions_and_batch_together:
                             with FormRow(elem_id="txt2img_column_batch"):
@@ -560,7 +571,11 @@ def create_ui():
                     hr_second_pass_steps,
                     hr_resize_x,
                     hr_resize_y,
+                    hr_sampler_index,
+                    hr_prompt,
+                    hr_negative_prompt,
                     override_settings,
+
                 ] + custom_inputs,
 
                 outputs=[
@@ -631,6 +646,9 @@ def create_ui():
                 (hr_second_pass_steps, "Hires steps"),
                 (hr_resize_x, "Hires resize-1"),
                 (hr_resize_y, "Hires resize-2"),
+                (hr_sampler_index, "Hires sampling method"),
+                (hr_prompt, "Hires prompt"),
+                (hr_negative_prompt, "Hires negative prompt"),
                 *modules.scripts.scripts_txt2img.infotext_fields
             ]
             parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
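
Each new paste-field tuple binds a UI component to the infotext key whose parsed value should repopulate it; this is what lets the "Hires prompt" and "Hires negative prompt" keys written by processing.py round-trip back into the two textboxes. A usage sketch (assuming the parsing entry point is parse_generation_parameters, which builds the res dict shown in generation_parameters_copypaste.py above):

    res = parse_generation_parameters(infotext)
    hr_prompt_value = res.get("Hires prompt", "")
    hr_negative_value = res.get("Hires negative prompt", "")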