Sfoglia il codice sorgente

support scheduler selection in hires fix

AUTOMATIC1111 1 anno fa
parent
commit
9aa9e980a9

+ 3 - 0
modules/infotext_utils.py

@@ -314,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Hires sampler" not in res:
     if "Hires sampler" not in res:
         res["Hires sampler"] = "Use same sampler"
         res["Hires sampler"] = "Use same sampler"
 
 
+    if "Hires schedule type" not in res:
+        res["Hires schedule type"] = "Use same scheduler"
+
     if "Hires checkpoint" not in res:
     if "Hires checkpoint" not in res:
         res["Hires checkpoint"] = "Use same checkpoint"
         res["Hires checkpoint"] = "Use same checkpoint"
 
 

+ 6 - 0
modules/processing.py

@@ -1115,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     hr_resize_y: int = 0
     hr_resize_y: int = 0
     hr_checkpoint_name: str = None
     hr_checkpoint_name: str = None
     hr_sampler_name: str = None
     hr_sampler_name: str = None
+    hr_scheduler: str = None
     hr_prompt: str = ''
     hr_prompt: str = ''
     hr_negative_prompt: str = ''
     hr_negative_prompt: str = ''
     force_task_id: str = None
     force_task_id: str = None
@@ -1203,6 +1204,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
 
 
+            self.extra_generation_params["Hires schedule type"] = None  # to be set in sd_samplers_kdiffusion.py
+
+            if self.hr_scheduler is None:
+                self.hr_scheduler = self.scheduler
+
             self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
             self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
             if self.enable_hr and self.latent_scale_mode is None:
             if self.enable_hr and self.latent_scale_mode is None:
                 if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                 if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):

+ 2 - 36
modules/processing_scripts/sampler.py

@@ -1,44 +1,10 @@
 import gradio as gr
 import gradio as gr
-import functools
 
 
 from modules import scripts, sd_samplers, sd_schedulers, shared
 from modules import scripts, sd_samplers, sd_schedulers, shared
 from modules.infotext_utils import PasteField
 from modules.infotext_utils import PasteField
 from modules.ui_components import FormRow, FormGroup
 from modules.ui_components import FormRow, FormGroup
 
 
 
 
-def get_sampler_from_infotext(d: dict):
-    return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
-
-
-def get_scheduler_from_infotext(d: dict):
-    return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
-
-
-@functools.cache
-def get_sampler_and_scheduler(sampler_name, scheduler_name):
-    default_sampler = sd_samplers.samplers[0]
-    found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
-
-    name = sampler_name or default_sampler.name
-
-    for scheduler in sd_schedulers.schedulers:
-        name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
-
-        for name_option in name_options:
-            if name.endswith(" " + name_option):
-                found_scheduler = scheduler
-                name = name[0:-(len(name_option) + 1)]
-                break
-
-    sampler = sd_samplers.all_samplers_map.get(name, default_sampler)
-
-    # revert back to Automatic if it's the default scheduler for the selected sampler
-    if sampler.options.get('scheduler', None) == found_scheduler.name:
-        found_scheduler = sd_schedulers.schedulers[0]
-
-    return sampler.name, found_scheduler.label
-
-
 class ScriptSampler(scripts.ScriptBuiltinUI):
 class ScriptSampler(scripts.ScriptBuiltinUI):
     section = "sampler"
     section = "sampler"
 
 
@@ -67,8 +33,8 @@ class ScriptSampler(scripts.ScriptBuiltinUI):
 
 
         self.infotext_fields = [
         self.infotext_fields = [
             PasteField(self.steps, "Steps", api="steps"),
             PasteField(self.steps, "Steps", api="steps"),
-            PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"),
-            PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"),
+            PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
+            PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
         ]
         ]
 
 
         return self.steps, self.sampler_name, self.scheduler
         return self.steps, self.sampler_name, self.scheduler

+ 59 - 1
modules/sd_samplers.py

@@ -1,6 +1,8 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common
+import functools
+
+from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
 
 
 # imports for functions that previously were here and are used by other modules
 # imports for functions that previously were here and are used by other modules
 samples_to_image_grid = sd_samplers_common.samples_to_image_grid
 samples_to_image_grid = sd_samplers_common.samples_to_image_grid
@@ -64,4 +66,60 @@ def visible_samplers():
     return [x for x in samplers if x.name not in samplers_hidden]
     return [x for x in samplers if x.name not in samplers_hidden]
 
 
 
 
+def get_sampler_from_infotext(d: dict):
+    return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
+
+
+def get_scheduler_from_infotext(d: dict):
+    return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
+
+
+def get_hr_sampler_and_scheduler(d: dict):
+    hr_sampler = d.get("Hires sampler", "Use same sampler")
+    sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
+
+    hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
+    scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
+
+    sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
+
+    sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
+    scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
+
+    return sampler, scheduler
+
+
+def get_hr_sampler_from_infotext(d: dict):
+    return get_hr_sampler_and_scheduler(d)[0]
+
+
+def get_hr_scheduler_from_infotext(d: dict):
+    return get_hr_sampler_and_scheduler(d)[1]
+
+
+@functools.cache
+def get_sampler_and_scheduler(sampler_name, scheduler_name):
+    default_sampler = samplers[0]
+    found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
+
+    name = sampler_name or default_sampler.name
+
+    for scheduler in sd_schedulers.schedulers:
+        name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
+
+        for name_option in name_options:
+            if name.endswith(" " + name_option):
+                found_scheduler = scheduler
+                name = name[0:-(len(name_option) + 1)]
+                break
+
+    sampler = all_samplers_map.get(name, default_sampler)
+
+    # revert back to Automatic if it's the default scheduler for the selected sampler
+    if sampler.options.get('scheduler', None) == found_scheduler.name:
+        found_scheduler = sd_schedulers.schedulers[0]
+
+    return sampler.name, found_scheduler.label
+
+
 set_samplers()
 set_samplers()

+ 4 - 2
modules/sd_samplers_kdiffusion.py

@@ -79,7 +79,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
 
 
         steps += 1 if discard_next_to_last_sigma else 0
         steps += 1 if discard_next_to_last_sigma else 0
 
 
-        scheduler_name = p.scheduler or 'Automatic'
+        scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
         if scheduler_name == 'Automatic':
         if scheduler_name == 'Automatic':
             scheduler_name = self.config.options.get('scheduler', None)
             scheduler_name = self.config.options.get('scheduler', None)
 
 
@@ -95,8 +95,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         else:
         else:
             sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
             sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
 
 
-            if scheduler.label != 'Automatic':
+            if scheduler.label != 'Automatic' and not p.is_hr_pass:
                 p.extra_generation_params["Schedule type"] = scheduler.label
                 p.extra_generation_params["Schedule type"] = scheduler.label
+            elif scheduler.label != p.extra_generation_params.get("Schedule type"):
+                p.extra_generation_params["Hires schedule type"] = scheduler.label
 
 
             if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
             if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
                 sigmas_kwargs['sigma_min'] = opts.sigma_min
                 sigmas_kwargs['sigma_min'] = opts.sigma_min

+ 2 - 1
modules/txt2img.py

@@ -11,7 +11,7 @@ from PIL import Image
 import gradio as gr
 import gradio as gr
 
 
 
 
-def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
+def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
     override_settings = create_override_settings_dict(override_settings_texts)
     override_settings = create_override_settings_dict(override_settings_texts)
 
 
     if force_enable_hr:
     if force_enable_hr:
@@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
         hr_resize_y=hr_resize_y,
         hr_resize_y=hr_resize_y,
         hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
         hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
         hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
         hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
+        hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
         hr_prompt=hr_prompt,
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,
         hr_negative_prompt=hr_negative_prompt,
         override_settings=override_settings,
         override_settings=override_settings,

+ 7 - 4
modules/ui.py

@@ -322,10 +322,11 @@ def create_ui():
 
 
                                 with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
                                 with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
 
 
-                                    hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+                                    hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
                                     create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
                                     create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
 
 
-                                    hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
+                                    hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
+                                    hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
 
 
                                 with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
                                 with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
                                     with gr.Column(scale=80):
                                     with gr.Column(scale=80):
@@ -394,6 +395,7 @@ def create_ui():
                 hr_resize_y,
                 hr_resize_y,
                 hr_checkpoint_name,
                 hr_checkpoint_name,
                 hr_sampler_name,
                 hr_sampler_name,
+                hr_scheduler,
                 hr_prompt,
                 hr_prompt,
                 hr_negative_prompt,
                 hr_negative_prompt,
                 override_settings,
                 override_settings,
@@ -456,8 +458,9 @@ def create_ui():
                 PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
                 PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
                 PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
                 PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
                 PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
                 PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
-                PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
-                PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
+                PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
+                PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
+                PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
                 PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
                 PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
                 PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
                 PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
                 PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
                 PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),