
Merge pull request #12181 from AUTOMATIC1111/hires_checkpoint

Hires fix change checkpoint
AUTOMATIC1111 2 years ago
parent
commit
0ae2767ae6
6 changed files with 108 additions and 49 deletions
  1. modules/generation_parameters_copypaste.py  +3 -0
  2. modules/processing.py  +66 -31
  3. modules/sd_models.py  +17 -10
  4. modules/shared.py  +13 -6
  5. modules/txt2img.py  +2 -1
  6. modules/ui.py  +7 -1

modules/generation_parameters_copypaste.py  +3 -0

@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Hires sampler" not in res:
         res["Hires sampler"] = "Use same sampler"
 
+    if "Hires checkpoint" not in res:
+        res["Hires checkpoint"] = "Use same checkpoint"
+
     if "Hires prompt" not in res:
         res["Hires prompt"] = ""
 

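Note: this hunk backfills a default during infotext parsing, so parameters saved before this change still restore cleanly; when the pasted text has no "Hires checkpoint" key, the dropdown falls back to "Use same checkpoint". A minimal sketch of the idea (the res dict mirrors the one built by the surrounding function; the sample values are made up):

    # Defaulting missing keys when restoring pasted generation parameters.
    res = {"Steps": "20", "Sampler": "Euler a"}  # an old infotext, no hires keys

    for key, default in [
        ("Hires sampler", "Use same sampler"),
        ("Hires checkpoint", "Use same checkpoint"),
        ("Hires prompt", ""),
    ]:
        if key not in res:
            res[key] = default

    assert res["Hires checkpoint"] == "Use same checkpoint"
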
modules/processing.py  +66 -31

@@ -539,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
     return x
 
 
+class DecodedSamples(list):
+    already_decoded = True
+
+
 def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
-    samples = []
+    samples = DecodedSamples()
 
     for i in range(batch.shape[0]):
         sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -788,7 +792,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+            if getattr(samples_ddim, 'already_decoded', False):
+                x_samples_ddim = samples_ddim
+            else:
+                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
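Note: decode_latent_batch now returns a DecodedSamples list whose class attribute marks its contents as already decoded, letting process_images_inner accept either raw latents or decoded images from p.sample() without a type check. A self-contained sketch of the sentinel pattern (names other than DecodedSamples are illustrative):

    # A list subclass whose class attribute tags its contents as decoded.
    class DecodedSamples(list):
        already_decoded = True

    def decode(latents):  # stand-in for decode_latent_batch
        return DecodedSamples(x * 2 for x in latents)

    def consume(samples):
        # Plain lists lack the attribute, so getattr falls back to False.
        if getattr(samples, "already_decoded", False):
            return samples      # p.sample() already ran the VAE decode
        return decode(samples)  # still latents; decode them here

    assert consume([1, 2]) == [2, 4]
    assert consume(DecodedSamples([2, 4])) == [2, 4]
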
@@ -930,7 +938,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     cached_hr_uc = [None, None]
     cached_hr_c = [None, None]
 
-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
@@ -941,11 +949,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_resize_y = hr_resize_y
         self.hr_upscale_to_x = hr_resize_x
         self.hr_upscale_to_y = hr_resize_y
+        self.hr_checkpoint_name = hr_checkpoint_name
+        self.hr_checkpoint_info = None
         self.hr_sampler_name = hr_sampler_name
         self.hr_prompt = hr_prompt
         self.hr_negative_prompt = hr_negative_prompt
         self.all_hr_prompts = None
         self.all_hr_negative_prompts = None
+        self.latent_scale_mode = None
 
         if firstphase_width != 0 or firstphase_height != 0:
             self.hr_upscale_to_x = self.width
@@ -968,6 +979,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
+            if self.hr_checkpoint_name:
+                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+                if self.hr_checkpoint_info is None:
+                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
 
@@ -977,6 +996,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                 self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
 
+            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+            if self.enable_hr and self.latent_scale_mode is None:
+                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                 self.hr_resize_x = self.width
                 self.hr_resize_y = self.height
@@ -1015,14 +1039,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                     self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                     self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
 
-            # special case: the user has chosen to do nothing
-            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
-                self.enable_hr = False
-                self.denoising_strength = None
-                self.extra_generation_params.pop("Hires upscale", None)
-                self.extra_generation_params.pop("Hires resize", None)
-                return
-
             if not state.processing_has_refined_job_count:
                 if state.job_count == -1:
                     state.job_count = self.n_iter
@@ -1040,17 +1056,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
-        if self.enable_hr and latent_scale_mode is None:
-            if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
-                raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+        del x
 
         if not self.enable_hr:
             return samples
 
+        if self.latent_scale_mode is None:
+            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+        else:
+            decoded_samples = None
+
+        current = shared.sd_model.sd_checkpoint_info
+        try:
+            if self.hr_checkpoint_info is not None:
+                self.sampler = None
+                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+                devices.torch_gc()
+
+            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+        finally:
+            self.sampler = None
+            sd_models.reload_model_weights(info=current)
+            devices.torch_gc()
+            devices.torch_gc()
+
+    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
         self.is_hr_pass = True
 
         target_width = self.hr_upscale_to_x
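Note: sample now decodes the first-pass latents up front when a non-latent upscaler is selected (that decode must use the first checkpoint's VAE), then swaps in the hires checkpoint and delegates to the new sample_hr_pass, restoring the original weights in finally even if the second pass raises. A generic sketch of that swap-and-restore shape (the function names are placeholders, not webui API):

    # Swap a resource for one block of work; restore it no matter what.
    def run_second_pass(load_weights, current_info, hr_info, work):
        try:
            if hr_info is not None:
                load_weights(hr_info)   # switch to the hires checkpoint
            return work()               # sample_hr_pass in the real code
        finally:
            load_weights(current_info)  # always restore the first checkpoint

    log = []
    run_second_pass(log.append, "first", "hires", lambda: log.append("sampled"))
    assert log == ["hires", "sampled", "first"]
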
@@ -1068,11 +1099,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
             images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
 
-        if latent_scale_mode is not None:
+        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+            img2img_sampler_name = 'DDIM'
+
+        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+        if self.latent_scale_mode is not None:
             for i in range(samples.shape[0]):
                 save_intermediate(samples, i)
 
-            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
 
             # Avoid making the inpainting conditioning unless necessary as
             # this does need some extra compute to decode / encode the image again.
@@ -1081,7 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)
         else:
-            decoded_samples = decode_first_stage(self.sd_model, samples)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
             batch_images = []
@@ -1100,6 +1137,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             decoded_samples = torch.from_numpy(np.array(batch_images))
             decoded_samples = decoded_samples.to(shared.device)
             decoded_samples = 2. * decoded_samples - 1.
+            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
 
             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 
@@ -1107,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         shared.state.nextjob()
 
-        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
-            img2img_sampler_name = 'DDIM'
-
-        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
 
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
 
         # GC now before running the next img2img to prevent running out of memory
-        x = None
         devices.torch_gc()
 
         if not self.disable_extra_networks:
@@ -1138,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
 
+        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
         self.is_hr_pass = False
 
-        return samples
+        return decoded_samples
 
     def close(self):
         super().close()
@@ -1179,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if self.hr_c is not None:
             return
 
-        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
-        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
 
     def setup_conds(self):
         super().setup_conds()
@@ -1188,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_uc = None
         self.hr_c = None
 
-        if self.enable_hr:
+        if self.enable_hr and self.hr_checkpoint_info is None:
             if shared.opts.hires_fix_use_firstpass_conds:
                 self.calculate_hr_conds()
 

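Note: the last two hunks of this file relate to the checkpoint switch. The hires conds are now built from prompt_parser.SdConditioning lists that carry the second-pass target size, and setup_conds no longer precomputes them when a separate hires checkpoint is configured, presumably because they must be computed once that checkpoint is actually loaded. A simplified stand-in for the carrier-list idea (not the real SdConditioning class):

    # A prompt list that also carries the resolution the conditioning
    # should be computed for.
    class SizedPrompts(list):
        def __init__(self, prompts, width=None, height=None, is_negative_prompt=False):
            super().__init__(prompts)
            self.width = width
            self.height = height
            self.is_negative_prompt = is_negative_prompt

    hr_prompts = SizedPrompts(["a photo of a cat"], width=1024, height=1024)
    assert (hr_prompts.width, hr_prompts.height) == (1024, 1024)
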
modules/sd_models.py  +17 -10

@@ -67,6 +67,7 @@ class CheckpointInfo:
         self.shorthash = self.sha256[0:10] if self.sha256 else None
 
         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
 
         self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
 
@@ -87,6 +88,7 @@ class CheckpointInfo:
 
         checkpoints_list.pop(self.title, None)
         self.title = f'{self.name} [{self.shorthash}]'
+        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
         self.register()
 
         return self.shorthash
@@ -107,14 +109,8 @@ def setup_model():
     enable_midas_autodownload()
 
 
-def checkpoint_tiles():
-    def convert(name):
-        return int(name) if name.isdigit() else name.lower()
-
-    def alphanumeric_key(key):
-        return [convert(c) for c in re.split('([0-9]+)', key)]
-
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
 
 
 def list_models():
@@ -137,11 +133,14 @@ def list_models():
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
 
-    for filename in sorted(model_list, key=str.lower):
+    for filename in model_list:
         checkpoint_info = CheckpointInfo(filename)
         checkpoint_info.register()
 
 
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
 def get_closet_checkpoint_match(search_string):
     checkpoint_info = checkpoint_aliases.get(search_string, None)
     if checkpoint_info is not None:
@@ -151,6 +150,11 @@ def get_closet_checkpoint_match(search_string):
     if found:
         return found[0]
 
+    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
     return None
 
 
@@ -585,7 +589,10 @@ def reload_model_weights(sd_model=None, info=None):
     timer.record("find config")
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
-        del sd_model
+        if sd_model is not None:
+            sd_model.to(device="meta")
+
+        devices.torch_gc()
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         return model_data.sd_model
 

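Note: besides the new short_title and the use_short flag on checkpoint_tiles, this file makes two behavioral changes: get_closet_checkpoint_match retries the title search after stripping a trailing "[hash]" suffix, so a short title pasted from an infotext can still resolve a checkpoint stored under a subfolder, and reload_model_weights parks the outgoing model on the meta device instead of only del-ing the reference, so its weights are actually released before the replacement loads. A self-contained check of the checksum-stripping regex (the file names are made up):

    import re

    # Same pattern as re_strip_checksum: a trailing " [....]" suffix.
    re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")

    titles = ["sd15/v1-5-pruned.safetensors [abc123de]", "v2-1.ckpt"]
    search = "v1-5-pruned [abc123de]"

    without_checksum = re_strip_checksum.sub("", search)  # -> "v1-5-pruned"
    found = sorted((t for t in titles if without_checksum in t), key=len)
    assert found[0].startswith("sd15/v1-5-pruned")
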
modules/shared.py  +13 -6

@@ -220,12 +220,19 @@ class State:
             return
 
         import modules.sd_samplers
-        if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-        else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
 
-        self.current_image_sampling_step = self.sampling_step
+        try:
+            if opts.show_progress_grid:
+                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            else:
+                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+            self.current_image_sampling_step = self.sampling_step
+
+        except Exception:
+            # when switching models during generation, VAE would be on CPU, so creating an image will fail.
+            # we silently ignore this error
+            errors.record_exception()
 
     def assign_current_image(self, image):
         self.current_image = image
@@ -521,7 +528,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
-    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
+    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(),
     "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
 }))

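Note: the preview update is wrapped in try/except because, while the hires checkpoint is being swapped in, the VAE can be on CPU and decoding the live-preview latent raises; the exception is recorded and generation continues. A generic sketch of that record-and-continue pattern (webui uses errors.record_exception(); logging stands in here):

    import logging

    def update_preview(render, latent):
        try:
            return render(latent)  # may fail mid checkpoint swap
        except Exception:
            # the preview is cosmetic, so log and keep generating
            logging.exception("live preview failed; continuing")
            return None
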
modules/txt2img.py  +2 -1

@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import gradio as gr
 
 
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     p = processing.StableDiffusionProcessingTxt2Img(
@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         hr_second_pass_steps=hr_second_pass_steps,
         hr_resize_x=hr_resize_x,
         hr_resize_y=hr_resize_y,
+        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
         hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,

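Note: the UI passes its dropdown strings straight through, so txt2img normalizes them: the "Use same checkpoint" sentinel becomes None, and the sampler index is shifted by one because entry 0 is "Use same sampler". A tiny sketch of that normalization (the function name is illustrative):

    # Turn UI sentinel values into processing arguments.
    def normalize_hr_args(hr_checkpoint_name, hr_sampler_index, sampler_names):
        checkpoint = None if hr_checkpoint_name == "Use same checkpoint" else hr_checkpoint_name
        # index 0 is "Use same sampler"; real samplers start at offset 1
        sampler = sampler_names[hr_sampler_index - 1] if hr_sampler_index != 0 else None
        return checkpoint, sampler

    assert normalize_hr_args("Use same checkpoint", 0, ["DDIM"]) == (None, None)
    assert normalize_hr_args("v1-5-pruned", 1, ["DDIM"]) == ("v1-5-pruned", "DDIM")
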
modules/ui.py  +7 -1

@@ -456,6 +456,10 @@ def create_ui():
                                 hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
 
                             with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+
+                                hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+                                create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
+
                                 hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
 
                             with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
@@ -533,6 +537,7 @@ def create_ui():
                     hr_second_pass_steps,
                     hr_resize_x,
                     hr_resize_y,
+                    hr_checkpoint_name,
                     hr_sampler_index,
                     hr_prompt,
                     hr_negative_prompt,
@@ -599,8 +604,9 @@ def create_ui():
                 (hr_second_pass_steps, "Hires steps"),
                 (hr_resize_x, "Hires resize-1"),
                 (hr_resize_y, "Hires resize-2"),
+                (hr_checkpoint_name, "Hires checkpoint"),
                 (hr_sampler_index, "Hires sampler"),
-                (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
+                (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
                 (hr_prompt, "Hires prompt"),
                 (hr_negative_prompt, "Hires negative prompt"),
                 (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
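Note: the final hunk wires infotext restore to the new control: pasting parameters fills the dropdown, and the container row is revealed when either field differs from its "Use same ..." default. A sketch of the visibility rule (plain dicts stand in for gr.update):

    # Reveal the hires sampler/checkpoint row only when pasted
    # parameters override either default.
    def row_visibility(d):
        changed = (
            d.get("Hires sampler", "Use same sampler") != "Use same sampler"
            or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint"
        )
        return {"visible": True} if changed else {}

    assert row_visibility({"Hires checkpoint": "v1-5-pruned"}) == {"visible": True}
    assert row_visibility({}) == {}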