AUTOMATIC1111 1 год назад
Родитель
Сommit
5b2a60b8e2

+ 1 - 1
README.md

@@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
 
-- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - Spandrel - https://github.com/chaiNNer-org/spandrel implementing
   - GFPGAN - https://github.com/TencentARC/GFPGAN.git

+ 5 - 0
configs/sd3-inference.yaml

@@ -0,0 +1,5 @@
+model:
+  target: modules.models.sd3.sd3_model.SD3Inferencer
+  params:
+    shift: 3
+    state_dict: null

+ 3 - 1
extensions-builtin/Lora/networks.py

@@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
                 network_layer_mapping[network_name] = module
                 module.network_layer_name = network_name
     else:
-        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
+
+        for name, module in cond_stage_model.named_modules():
             network_name = name.replace(".", "_")
             network_layer_mapping[network_name] = module
             module.network_layer_name = network_name

+ 2 - 1
modules/models/sd3/mmdit.py

@@ -6,7 +6,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
-from other_impls import attention, Mlp
+from modules.models.sd3.other_impls import attention, Mlp
+
 
 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding"""

+ 7 - 7
modules/models/sd3/sd3_impls.py

@@ -1,7 +1,7 @@
 ### Impls of the SD3 core diffusion model and VAE
 
 import torch, math, einops
-from mmdit import MMDiT
+from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
 
 
@@ -46,16 +46,16 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
 
 class BaseModel(torch.nn.Module):
     """Wrapper around the core MM-DiT model"""
-    def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
+    def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
         super().__init__()
         # Important configuration values can be quickly determined by checking shapes in the source file
         # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
-        patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
-        depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
-        num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
+        patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
+        depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
+        num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
         pos_embed_max_size = round(math.sqrt(num_patches))
-        adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
-        context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
+        adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
+        context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
         context_embedder_config = {
             "target": "torch.nn.Linear",
             "params": {

+ 166 - 0
modules/models/sd3/sd3_model.py

@@ -0,0 +1,166 @@
+import contextlib
+import os
+from typing import Mapping
+
+import safetensors
+import torch
+
+import k_diffusion
+from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
+from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
+
+from modules import shared, modelloader, devices
+
+CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
+CLIPG_CONFIG = {
+    "hidden_act": "gelu",
+    "hidden_size": 1280,
+    "intermediate_size": 5120,
+    "num_attention_heads": 20,
+    "num_hidden_layers": 32,
+}
+
+CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
+CLIPL_CONFIG = {
+    "hidden_act": "quick_gelu",
+    "hidden_size": 768,
+    "intermediate_size": 3072,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
+}
+
+T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
+T5_CONFIG = {
+    "d_ff": 10240,
+    "d_model": 4096,
+    "num_heads": 64,
+    "num_layers": 24,
+    "vocab_size": 32128,
+}
+
+
+class SafetensorsMapping(Mapping):
+    def __init__(self, file):
+        self.file = file
+
+    def __len__(self):
+        return len(self.file.keys())
+
+    def __iter__(self):
+        for key in self.file.keys():
+            yield key
+
+    def __getitem__(self, key):
+        return self.file.get_tensor(key)
+
+
+class SD3Cond(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.tokenizer = SD3Tokenizer()
+
+        with torch.no_grad():
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
+
+        self.weights_loaded = False
+
+    def forward(self, prompts: list[str]):
+        res = []
+
+        for prompt in prompts:
+            tokens = self.tokenizer.tokenize_with_weights(prompt)
+            l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
+            g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+            lg_out = torch.cat([l_out, g_out], dim=-1)
+            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+            lgt_out = torch.cat([lg_out, t5_out], dim=-2)
+            vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
+
+            res.append({
+                'crossattn': lgt_out[0].to(devices.device),
+                'vector': vector_out[0].to(devices.device),
+            })
+
+        return res
+
+    def load_weights(self):
+        if self.weights_loaded:
+            return
+
+        clip_path = os.path.join(shared.models_path, "CLIP")
+
+        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+        with safetensors.safe_open(clip_g_file, framework="pt") as file:
+            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+
+        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+        with safetensors.safe_open(clip_l_file, framework="pt") as file:
+            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+        with safetensors.safe_open(t5_file, framework="pt") as file:
+            self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        self.weights_loaded = True
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        return torch.tensor([[0]], device=devices.device) # XXX
+
+
+class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
+    def __init__(self, inner_model, sigmas):
+        super().__init__(sigmas, quantize=shared.opts.enable_quantization)
+        self.inner_model = inner_model
+
+    def forward(self, input, sigma, **kwargs):
+        return self.inner_model.apply_model(input, sigma, **kwargs)
+
+
+class SD3Inferencer(torch.nn.Module):
+    def __init__(self, state_dict, shift=3, use_ema=False):
+        super().__init__()
+
+        self.shift = shift
+
+        with torch.no_grad():
+            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
+            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
+            self.first_stage_model.dtype = self.model.diffusion_model.dtype
+
+        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
+
+        self.cond_stage_model = SD3Cond()
+        self.cond_stage_key = 'txt'
+
+        self.parameterization = "eps"
+        self.model.conditioning_key = "crossattn"
+
+        self.latent_format = SD3LatentFormat()
+        self.latent_channels = 16
+
+    def after_load_weights(self):
+        self.cond_stage_model.load_weights()
+
+    def ema_scope(self):
+        return contextlib.nullcontext()
+
+    def get_learned_conditioning(self, batch: list[str]):
+        return self.cond_stage_model(batch)
+
+    def apply_model(self, x, t, cond):
+        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
+
+    def decode_first_stage(self, latent):
+        latent = self.latent_format.process_out(latent)
+        return self.first_stage_model.decode(latent)
+
+    def encode_first_stage(self, image):
+        latent = self.first_stage_model.encode(image)
+        return self.latent_format.process_in(latent)
+
+    def create_denoiser(self):
+        return SD3Denoiser(self, self.model.model_sampling.sigmas)

+ 2 - 1
modules/processing.py

@@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
-            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
 
             if p.scripts is not None:
                 p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

+ 74 - 13
modules/sd_models.py

@@ -1,7 +1,9 @@
 import collections
+import importlib
 import os
 import sys
 import threading
+import enum
 
 import torch
 import re
@@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
 from urllib import request
 import ldm.modules.midas as midas
 
-from ldm.util import instantiate_from_config
-
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 from modules.shared import opts
@@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()
 
 
+class ModelType(enum.Enum):
+    SD1 = 1
+    SD2 = 2
+    SDXL = 3
+    SSD = 4
+    SD3 = 5
+
+
 def replace_key(d, key, new_key, value):
     keys = list(d.keys())
 
@@ -368,6 +376,36 @@ def check_fp8(model):
     return enable_fp8
 
 
+def set_model_type(model, state_dict):
+    model.is_sd1 = False
+    model.is_sd2 = False
+    model.is_sdxl = False
+    model.is_ssd = False
+    model.is_ssd3 = False
+
+    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
+        model.is_sd3 = True
+        model.model_type = ModelType.SD3
+    elif hasattr(model, 'conditioner'):
+        model.is_sdxl = True
+
+        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
+            model.is_ssd = True
+            model.model_type = ModelType.SSD
+        else:
+            model.model_type = ModelType.SDXL
+    elif hasattr(model.cond_stage_model, 'model'):
+        model.is_sd2 = True
+        model.model_type = ModelType.SD2
+    else:
+        model.is_sd1 = True
+        model.model_type = ModelType.SD1
+
+
+def set_model_fields(model):
+    if not hasattr(model, 'latent_channels'):
+        model.latent_channels = 4
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
@@ -382,10 +420,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    model.is_sdxl = hasattr(model, 'conditioner')
-    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
-    model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    set_model_type(model, state_dict)
+    set_model_fields(model)
+
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
@@ -552,8 +589,7 @@ def patch_given_betas():
     original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
 
 
-def repair_config(sd_config):
-
+def repair_config(sd_config, state_dict=None):
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False
 
@@ -563,8 +599,9 @@ def repair_config(sd_config):
         elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
             sd_config.model.params.unet_config.params.use_fp16 = True
 
-    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
-        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
 
     # For UnCLIP-L, override the hardcoded karlo directory
     if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
@@ -580,6 +617,7 @@ def repair_config(sd_config):
         sd_config.model.params.unet_config.params.use_checkpoint = False
 
 
+
 def rescale_zero_terminal_snr_abar(alphas_cumprod):
     alphas_bar_sqrt = alphas_cumprod.sqrt()
 
@@ -715,6 +753,25 @@ def send_model_to_trash(m):
     devices.torch_gc()
 
 
+def instantiate_from_config(config, state_dict=None):
+    constructor = get_obj_from_str(config["target"])
+
+    params = {**config.get("params", {})}
+
+    if state_dict and "state_dict" in params and params["state_dict"] is None:
+        params["state_dict"] = state_dict
+
+    return constructor(**params)
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -739,7 +796,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("find config")
 
     sd_config = OmegaConf.load(checkpoint_config)
-    repair_config(sd_config)
+    repair_config(sd_config, state_dict)
 
     timer.record("load config")
 
@@ -749,7 +806,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
             with sd_disable_initialization.InitializeOnMeta():
-                sd_model = instantiate_from_config(sd_config.model)
+                sd_model = instantiate_from_config(sd_config.model, state_dict)
 
     except Exception as e:
         errors.display(e, "creating model quickly", full_traceback=True)
@@ -758,7 +815,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
 
         with sd_disable_initialization.InitializeOnMeta():
-            sd_model = instantiate_from_config(sd_config.model)
+            sd_model = instantiate_from_config(sd_config.model, state_dict)
 
     sd_model.used_config = checkpoint_config
 
@@ -775,6 +832,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+
+    if hasattr(sd_model, "after_load_weights"):
+        sd_model.after_load_weights()
+
     timer.record("load weights from state dict")
 
     send_model_to_device(sd_model)

+ 6 - 1
modules/sd_models_config.py

@@ -23,6 +23,8 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
+config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
+
 
 def is_using_v_parameterization_for_sd2(state_dict):
     """
@@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
 
+    if "model.diffusion_model.x_embedder.proj.weight" in sd:
+        return config_sd3
+
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
         else:
             return config_sdxl
+
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
         return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
@@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename):
         if diffusion_model_input.shape[1] == 8:
             return config_instruct_pix2pix
 
-
     if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
         if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
             return config_alt_diffusion_m18

+ 6 - 0
modules/sd_models_types.py

@@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):
 
     is_sd1: bool
     """True if the model's architecture is SD 1.x"""
+
+    is_sd3: bool
+    """True if the model's architecture is SD 3"""
+
+    latent_channels: int
+    """number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""

+ 2 - 2
modules/sd_samplers_common.py

@@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     else:
         if model is None:
             model = shared.sd_model
-        with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+        with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
             x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
 
     return x_sample
@@ -246,7 +246,7 @@ class Sampler:
         self.eta_infotext_field = 'Eta'
         self.eta_default = 1.0
 
-        self.conditioning_key = shared.sd_model.model.conditioning_key
+        self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
 
         self.p = None
         self.model_wrap_cfg = None

+ 7 - 2
modules/sd_samplers_kdiffusion.py

@@ -53,8 +53,13 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
     @property
     def inner_model(self):
         if self.model_wrap is None:
-            denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-            self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
+
+            if denoiser_constructor is not None:
+                self.model_wrap = denoiser_constructor()
+            else:
+                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
 
         return self.model_wrap
 

+ 22 - 5
modules/sd_vae_approx.py

@@ -8,9 +8,9 @@ sd_vae_approx_models = {}
 
 
 class VAEApprox(nn.Module):
-    def __init__(self):
+    def __init__(self, latent_channels=4):
         super(VAEApprox, self).__init__()
-        self.conv1 = nn.Conv2d(4, 8, (7, 7))
+        self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
         self.conv2 = nn.Conv2d(8, 16, (5, 5))
         self.conv3 = nn.Conv2d(16, 32, (3, 3))
         self.conv4 = nn.Conv2d(32, 64, (3, 3))
@@ -40,7 +40,13 @@ def download_model(model_path, model_url):
 
 
 def model():
-    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    if shared.sd_model.is_sd3:
+        model_name = "vaeapprox-sd3.pt"
+    elif shared.sd_model.is_sdxl:
+        model_name = "vaeapprox-sdxl.pt"
+    else:
+        model_name = "model.pt"
+
     loaded_model = sd_vae_approx_models.get(model_name)
 
     if loaded_model is None:
@@ -52,7 +58,7 @@ def model():
             model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
             download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
 
-        loaded_model = VAEApprox()
+        loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
         loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
         loaded_model.eval()
         loaded_model.to(devices.device, devices.dtype)
@@ -64,7 +70,18 @@ def model():
 def cheap_approximation(sample):
     # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
 
-    if shared.sd_model.is_sdxl:
+    if shared.sd_model.is_sd3:
+        coeffs = [
+            [-0.0645,  0.0177,  0.1052], [ 0.0028,  0.0312,  0.0650],
+            [ 0.1848,  0.0762,  0.0360], [ 0.0944,  0.0360,  0.0889],
+            [ 0.0897,  0.0506, -0.0364], [-0.0020,  0.1203,  0.0284],
+            [ 0.0855,  0.0118,  0.0283], [-0.0539,  0.0658,  0.1047],
+            [-0.0057,  0.0116,  0.0700], [-0.0412,  0.0281, -0.0039],
+            [ 0.1106,  0.1171,  0.1220], [-0.0248,  0.0682, -0.0481],
+            [ 0.0815,  0.0846,  0.1207], [-0.0120, -0.0055, -0.0867],
+            [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
+        ]
+    elif shared.sd_model.is_sdxl:
         coeffs = [
             [ 0.3448,  0.4168,  0.4395],
             [-0.1953, -0.0290,  0.0250],

+ 30 - 10
modules/sd_vae_taesd.py

@@ -34,9 +34,9 @@ class Block(nn.Module):
         return self.fuse(self.conv(x) + self.skip(x))
 
 
-def decoder():
+def decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(), conv(4, 64), nn.ReLU(),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@@ -44,13 +44,13 @@ def decoder():
     )
 
 
-def encoder():
+def encoder(latent_channels=4):
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 4),
+        conv(64, latent_channels),
     )
 
 
@@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, decoder_path="taesd_decoder.pth"):
+    def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.decoder = decoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+
+        self.decoder = decoder(latent_channels)
         self.decoder.load_state_dict(
             torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
@@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, encoder_path="taesd_encoder.pth"):
+    def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.encoder = encoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+
+        self.encoder = encoder(latent_channels)
         self.encoder.load_state_dict(
             torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
@@ -87,7 +95,13 @@ def download_model(model_path, model_url):
 
 
 def decoder_model():
-    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_decoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_decoder.pth"
+    else:
+        model_name = "taesd_decoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)
 
     if loaded_model is None:
@@ -106,7 +120,13 @@ def decoder_model():
 
 
 def encoder_model():
-    model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_encoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_encoder.pth"
+    else:
+        model_name = "taesd_encoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)
 
     if loaded_model is None: