
add an option (on by default) to disable T5
revert t5xxl back to fp16

AUTOMATIC1111 1 year ago
parent
commit
34b4443cc3
2 files changed with 20 additions and 6 deletions
  1. modules/models/sd3/sd3_model.py (+16 -6)
  2. modules/shared_options.py (+4 -0)

+ 16 - 6
modules/models/sd3/sd3_model.py

@@ -29,7 +29,7 @@ CLIPL_CONFIG = {
     "num_hidden_layers": 12,
     "num_hidden_layers": 12,
 }
 }
 
 
-T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
+T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
 T5_CONFIG = {
 T5_CONFIG = {
     "d_ff": 10240,
     "d_ff": 10240,
     "d_model": 4096,
     "d_model": 4096,
@@ -63,7 +63,11 @@ class SD3Cond(torch.nn.Module):
         with torch.no_grad():
             self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
             self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
-            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+
+            if shared.opts.sd3_enable_t5:
+                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+            else:
+                self.t5xxl = None

         self.weights_loaded = False

@@ -74,7 +78,12 @@ class SD3Cond(torch.nn.Module):
             tokens = self.tokenizer.tokenize_with_weights(prompt)
             l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
             g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
-            t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+
+            if self.t5xxl and shared.opts.sd3_enable_t5:
+                t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+            else:
+                t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)
+
             lg_out = torch.cat([l_out, g_out], dim=-1)
             lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
             lgt_out = torch.cat([lg_out, t5_out], dim=-2)
@@ -101,9 +110,10 @@ class SD3Cond(torch.nn.Module):
         with safetensors.safe_open(clip_l_file, framework="pt") as file:
             self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

-        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
-        with safetensors.safe_open(t5_file, framework="pt") as file:
-            self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+        if self.t5xxl:
+            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+            with safetensors.safe_open(t5_file, framework="pt") as file:
+                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

         self.weights_loaded = True

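
A minimal standalone sketch of what the patched encode path above produces when T5 is disabled. The shapes are illustrative assumptions (batch 1, 77 tokens; CLIP-L hidden size 768, CLIP-G hidden size 1280), not values taken from the diff:

import torch

l_out = torch.randn(1, 77, 768)   # stand-in for clip_l hidden states
g_out = torch.randn(1, 77, 1280)  # stand-in for clip_g hidden states

# With sd3_enable_t5 off, a zero tensor takes the place of the T5 sequence,
# matching T5-XXL's d_model of 4096 from T5_CONFIG.
t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)

# CLIP-L and CLIP-G are concatenated on the feature axis (768 + 1280 = 2048)
# and zero-padded on the right up to 4096 features...
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

# ...then the (here all-zero) T5 sequence is appended along the token axis,
# so downstream code always sees the same conditioning shape.
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
print(lgt_out.shape)  # torch.Size([1, 154, 4096])
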
+ 4 - 0
modules/shared_options.py

@@ -191,6 +191,10 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"),
     "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
     "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
 }))
 }))
 
 
+options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
+    "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
+}))
+
 options_templates.update(options_section(('vae', "VAE", "sd"), {
     "sd_vae_explanation": OptionHTML("""
 <abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
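
For reference, a hypothetical sketch of what a SafetensorsMapping-style wrapper (used in the loading code above) might look like; the real class lives elsewhere in the repo, so the name LazySafetensorsMapping and the details here are assumptions:

from collections.abc import Mapping

class LazySafetensorsMapping(Mapping):
    # Hypothetical stand-in for the repo's SafetensorsMapping: wraps a file
    # handle from safetensors.safe_open() so load_state_dict() can fetch
    # tensors one key at a time instead of materializing the whole checkpoint.
    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        return iter(self.file.keys())

    def __getitem__(self, key):
        return self.file.get_tensor(key)
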