@@ -29,7 +29,7 @@ CLIPL_CONFIG = {
     "num_hidden_layers": 12,
 }
 
-T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
+T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
 T5_CONFIG = {
     "d_ff": 10240,
     "d_model": 4096,
@@ -63,7 +63,11 @@ class SD3Cond(torch.nn.Module):
         with torch.no_grad():
             self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
             self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
-            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+
+            if shared.opts.sd3_enable_t5:
+                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+            else:
+                self.t5xxl = None
 
         self.weights_loaded = False
 
@@ -74,7 +78,12 @@ class SD3Cond(torch.nn.Module):
         tokens = self.tokenizer.tokenize_with_weights(prompt)
         l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
         g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
-        t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+
+        if self.t5xxl and shared.opts.sd3_enable_t5:
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+        else:
+            t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)
+
         lg_out = torch.cat([l_out, g_out], dim=-1)
         lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
         lgt_out = torch.cat([lg_out, t5_out], dim=-2)
@@ -101,9 +110,10 @@ class SD3Cond(torch.nn.Module):
         with safetensors.safe_open(clip_l_file, framework="pt") as file:
             self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
-        with safetensors.safe_open(t5_file, framework="pt") as file:
-            self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+        if self.t5xxl:
+            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+            with safetensors.safe_open(t5_file, framework="pt") as file:
+                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
         self.weights_loaded = True
 
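For reference, below is a minimal sketch (not part of the patch) of the conditioning-tensor shapes the changed encoding path produces when T5 is disabled. The 768/1280 hidden sizes for CLIP-L/CLIP-G and the 77-token context length are assumptions about the standard SD3 text encoders, not values stated in the diff; only the 4096-wide T5 placeholder and the cat/pad calls come from the code above.

import torch

batch, seq = 1, 77                        # assumed token context length
l_out = torch.randn(batch, seq, 768)      # CLIP-L hidden states (assumed width)
g_out = torch.randn(batch, seq, 1280)     # CLIP-G hidden states (assumed width)

# With T5 disabled, the patch substitutes zeros of T5's 4096-wide hidden size:
t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)

lg_out = torch.cat([l_out, g_out], dim=-1)                              # -> (1, 77, 2048)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))  # -> (1, 77, 4096)
lgt_out = torch.cat([lg_out, t5_out], dim=-2)                           # -> (1, 154, 4096)
print(lgt_out.shape)                      # torch.Size([1, 154, 4096])

Because the concatenation runs along dim=-2, a disabled T5 still contributes 77 zero tokens, so the final context length and the shape the diffusion model receives stay the same as with T5 enabled.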