Эх сурвалжийг харах

support loading clip/t5 from the main model checkpoint

AUTOMATIC1111 1 жил өмнө
parent
commit
7e4b06fcd0

+ 11 - 19
modules/models/sd3/sd3_cond.py

@@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module):
             self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
             self.model_t5 = Sd3T5(self.t5xxl)
 
-        self.weights_loaded = False
-
     def forward(self, prompts: list[str]):
         with devices.without_autocast():
             lg_out, vector_out = self.model_lg(prompts)
-
-            token_count = lg_out.shape[1]
-
-            t5_out = self.model_t5(prompts, token_count=token_count)
+            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
             lgt_out = torch.cat([lg_out, t5_out], dim=-2)
 
         return {
@@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module):
             'vector': vector_out,
         }
 
-    def load_weights(self):
-        if self.weights_loaded:
-            return
-
+    def before_load_weights(self, state_dict):
         clip_path = os.path.join(shared.models_path, "CLIP")
 
-        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
-        with safetensors.safe_open(clip_g_file, framework="pt") as file:
-            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
+            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+            with safetensors.safe_open(clip_g_file, framework="pt") as file:
+                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
 
-        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
-        with safetensors.safe_open(clip_l_file, framework="pt") as file:
-            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
+            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+            with safetensors.safe_open(clip_l_file, framework="pt") as file:
+                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        if self.t5xxl:
+        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
             t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        self.weights_loaded = True
-
     def encode_embedding_init_text(self, init_text, nvpt):
         return torch.tensor([[0]], device=devices.device) # XXX
 

+ 7 - 3
modules/models/sd3/sd3_model.py

@@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module):
 
         self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
 
-        self.cond_stage_model = SD3Cond()
+        self.text_encoders = SD3Cond()
         self.cond_stage_key = 'txt'
 
         self.parameterization = "eps"
@@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module):
         self.latent_format = SD3LatentFormat()
         self.latent_channels = 16
 
-    def after_load_weights(self):
-        self.cond_stage_model.load_weights()
+    @property
+    def cond_stage_model(self):
+        return self.text_encoders
+
+    def before_load_weights(self, state_dict):
+        self.cond_stage_model.before_load_weights(state_dict)
 
     def ema_scope(self):
         return contextlib.nullcontext()

+ 6 - 3
modules/sd_models.py

@@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
 
+    if hasattr(model, "before_load_weights"):
+        model.before_load_weights(state_dict)
+
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")
 
+    if hasattr(model, "after_load_weights"):
+        model.after_load_weights(state_dict)
+
     del state_dict
 
     # Set is_sdxl_inpaint flag.
@@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
-    if hasattr(sd_model, "after_load_weights"):
-        sd_model.after_load_weights()
-
     timer.record("load weights from state dict")
 
     send_model_to_device(sd_model)