
prevent accidental creation of CLIP models in float32 when the user wants float16

AUTOMATIC1111, 1 year ago
commit b443fdcf76
2 changed files with 4 additions and 3 deletions
  1. modules/models/sd3/sd3_model.py (+3 -3)
  2. modules/sd_models.py (+1 -0)

modules/models/sd3/sd3_model.py (+3 -3)

@@ -61,9 +61,9 @@ class SD3Cond(torch.nn.Module):
         self.tokenizer = SD3Tokenizer()
 
         with torch.no_grad():
-            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
-            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
-            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
 
         self.weights_loaded = False
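
What the hunk buys: building a module in float32 and converting it afterwards briefly materialises full-precision weights, while passing the target dtype at construction allocates the weights in that dtype directly. Below is a minimal sketch of the difference, using a hypothetical TinyTextEncoder stand-in rather than the real SDXLClipG / SDClipModel / T5XXLModel classes (devices.dtype in the webui is the user's configured precision, typically torch.float16):

import torch

# Hypothetical stand-in for the SD3 text encoders: a module that
# allocates its parameters in whatever dtype it is constructed with.
class TinyTextEncoder(torch.nn.Module):
    def __init__(self, dtype=torch.float32):
        super().__init__()
        self.proj = torch.nn.Linear(4096, 4096, dtype=dtype)

# Old behaviour: weights are born float32 (4 bytes per element), so a
# full-precision allocation exists even if the model is halved later.
enc32 = TinyTextEncoder(dtype=torch.float32)
print(enc32.proj.weight.element_size())  # 4

# New behaviour: pass the target dtype (devices.dtype in the webui) so
# the weights are float16 (2 bytes per element) from the start.
enc16 = TinyTextEncoder(dtype=torch.float16)
print(enc16.proj.weight.element_size())  # 2

For an encoder with billions of parameters, such as T5-XXL, that transient float32 allocation costs gigabytes of extra memory, which is the accidental float32 creation the commit message refers to.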
 

modules/sd_models.py (+1 -0)

@@ -406,6 +406,7 @@ def set_model_fields(model):
     if not hasattr(model, 'latent_channels'):
         model.latent_channels = 4
 
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
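
The sd_models.py hunk is whitespace only: it adds the second blank line PEP 8 expects between top-level functions and does not change behaviour.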