Browse Source

add XL support for live previews: approx and TAESD

AUTOMATIC1111 2 years ago
parent
commit
b8159d0919
3 changed files with 40 additions and 25 deletions
  1. 1 1
      modules/sd_models_xl.py
  2. 26 11
      modules/sd_vae_approx.py
  3. 13 13
      modules/sd_vae_taesd.py

+ 1 - 1
modules/sd_models_xl.py

@@ -48,7 +48,7 @@ def extend_sdxl(model):
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
 
-    model.is_xl = True
+    model.is_sdxl = True
 
 
 sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning

+ 26 - 11
modules/sd_vae_approx.py

@@ -2,9 +2,9 @@ import os
 
 import torch
 from torch import nn
-from modules import devices, paths
+from modules import devices, paths, shared
 
-sd_vae_approx_model = None
+sd_vae_approx_models = {}
 
 
 class VAEApprox(nn.Module):
@@ -31,19 +31,34 @@ class VAEApprox(nn.Module):
         return x
 
 
+def download_model(model_path, model_url):
+    if not os.path.exists(model_path):
+        os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+        print(f'Downloading VAEApprox model to: {model_path}')
+        torch.hub.download_url_to_file(model_url, model_path)
+
+
 def model():
-    global sd_vae_approx_model
+    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    loaded_model = sd_vae_approx_models.get(model_name)
 
-    if sd_vae_approx_model is None:
-        model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
-        sd_vae_approx_model = VAEApprox()
+    if loaded_model is None:
+        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
         if not os.path.exists(model_path):
-            model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
-        sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
-        sd_vae_approx_model.eval()
-        sd_vae_approx_model.to(devices.device, devices.dtype)
+            model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
+
+        if not os.path.exists(model_path):
+            model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+            download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
+
+        loaded_model = VAEApprox()
+        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+        loaded_model.eval()
+        loaded_model.to(devices.device, devices.dtype)
+        sd_vae_approx_models[model_name] = loaded_model
 
-    return sd_vae_approx_model
+    return loaded_model
 
 
 def cheap_approximation(sample):

+ 13 - 13
modules/sd_vae_taesd.py

@@ -8,9 +8,9 @@ import os
 import torch
 import torch.nn as nn
 
-from modules import devices, paths_internal
+from modules import devices, paths_internal, shared
 
-sd_vae_taesd = None
+sd_vae_taesd_models = {}
 
 
 def conv(n_in, n_out, **kwargs):
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
         return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
 
 
-def download_model(model_path):
-    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
-
+def download_model(model_path, model_url):
     if not os.path.exists(model_path):
         os.makedirs(os.path.dirname(model_path), exist_ok=True)
 
@@ -72,17 +70,19 @@ def download_model(model_path):
 
 
 def model():
-    global sd_vae_taesd
+    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    loaded_model = sd_vae_taesd_models.get(model_name)
 
-    if sd_vae_taesd is None:
-        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
-        download_model(model_path)
+    if loaded_model is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
 
         if os.path.exists(model_path):
-            sd_vae_taesd = TAESD(model_path)
-            sd_vae_taesd.eval()
-            sd_vae_taesd.to(devices.device, devices.dtype)
+            loaded_model = TAESD(model_path)
+            loaded_model.eval()
+            loaded_model.to(devices.device, devices.dtype)
+            sd_vae_taesd_models[model_name] = loaded_model
         else:
             raise FileNotFoundError('TAESD model not found')
 
-    return sd_vae_taesd.decoder
+    return loaded_model.decoder