Răsfoiți Sursa

Fix sdxl inpaint

huchenlei 1 an în urmă
părinte
comite
d875cda565
2 a modificat fișierele cu 14 adăugiri și 9 ștergeri
  1. 2 2
      modules/processing.py
  2. 12 7
      modules/sd_models.py

+ 2 - 2
modules/processing.py

@@ -115,7 +115,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
         return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
 
     else:
-        if getattr(sd_model.model, "is_sdxl_inpaint", False):
+        if sd_model.is_sdxl_inpaint:
             # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
             image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
             image_conditioning = images_tensor_to_samples(image_conditioning,
@@ -389,7 +389,7 @@ class StableDiffusionProcessing:
         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)
 
-        if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False):
+        if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
 
         # Dummy zero conditioning if we're not using inpainting or depth model.

+ 12 - 7
modules/sd_models.py

@@ -386,13 +386,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
     model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
-    # Set is_sdxl_inpaint flag.
-    diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None)
-    model.is_sdxl_inpaint = (
-        model.is_sdxl and
-        diffusion_model_input is not None and
-        diffusion_model_input.shape[1] == 9
-    )
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
@@ -408,6 +401,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     del state_dict
 
+    # Set is_sdxl_inpaint flag.
+    # Perform this check after model initialization to make sure state_dict
+    # structure is already known.
+    diffusion_model_input = model.model.state_dict().get(
+        'diffusion_model.input_blocks.0.0.weight'
+    )
+    model.is_sdxl_inpaint = (
+        model.is_sdxl and
+        diffusion_model_input is not None and
+        diffusion_model_input.shape[1] == 9
+    )
+
     if shared.cmd_opts.opt_channelslast:
         model.to(memory_format=torch.channels_last)
         timer.record("apply channels_last")