ソースを参照

add TAESD for i2i and t2i

Kohaku-Blueleaf 2 年 前
コミット
75336dfc84

+ 6 - 7
modules/processing.py

@@ -573,9 +573,10 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
 
 
 def decode_first_stage(model, x):
-    x = model.decode_first_stage(x.to(devices.dtype_vae))
-
-    return x
+    from modules.sd_samplers_common import samples_to_images_tensor, approximation_indexes
+    x = x.to(devices.dtype_vae)
+    approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
+    return samples_to_images_tensor(x, approx_index, model)
 
 
 def get_fixed_seed(seed):
@@ -1344,10 +1345,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 
         image = torch.from_numpy(batch_images)
-        image = 2. * image - 1.
-        image = image.to(shared.device, dtype=devices.dtype_vae)
-
-        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+        from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes
+        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
         devices.torch_gc()
 
         if self.resize_mode == 3:

+ 33 - 5
modules/sd_samplers_common.py

@@ -23,19 +23,29 @@ def setup_img2img_steps(p, steps=None):
 approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
 
 
-def single_sample_to_image(sample, approximation=None):
+def samples_to_images_tensor(sample, approximation=None, model=None):
+    '''latents -> images [-1, 1]'''
     if approximation is None:
         approximation = approximation_indexes.get(opts.show_progress_type, 0)
 
     if approximation == 2:
-        x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
+        x_sample = sd_vae_approx.cheap_approximation(sample)
     elif approximation == 1:
-        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
     elif approximation == 3:
         x_sample = sample * 1.5
-        x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+        x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
+        x_sample = x_sample * 2 - 1
     else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+        if model is None:
+            model = shared.sd_model
+        x_sample = model.decode_first_stage(sample)
+    
+    return x_sample
+
+
+def single_sample_to_image(sample, approximation=None):
+    x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
 
     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -52,6 +62,24 @@ def samples_to_image_grid(samples, approximation=None):
     return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
 
 
+def images_tensor_to_samples(image, approximation=None, model=None):
+    '''image[0, 1] -> latent'''
+    if approximation is None:
+        approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
+
+    if approximation == 3:
+        image = image.to(devices.device, devices.dtype)
+        x_latent = sd_vae_taesd.encoder_model()(image) / 1.5
+    else:
+        if model is None:
+            model = shared.sd_model
+        image = image.to(shared.device, dtype=devices.dtype_vae)
+        image = image * 2 - 1
+        x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+
+    return x_latent
+
+
 def store_latent(decoded):
     state.current_latent = decoded
 

+ 1 - 1
modules/sd_vae_approx.py

@@ -81,6 +81,6 @@ def cheap_approximation(sample):
 
     coefs = torch.tensor(coeffs).to(sample.device)
 
-    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+    x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
 
     return x_sample

+ 44 - 8
modules/sd_vae_taesd.py

@@ -44,7 +44,17 @@ def decoder():
     )
 
 
-class TAESD(nn.Module):
+def encoder():
+    return nn.Sequential(
+        conv(3, 64), Block(64, 64),
+        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+        conv(64, 4),
+    )
+
+
+class TAESDDecoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
@@ -55,21 +65,28 @@ class TAESD(nn.Module):
         self.decoder.load_state_dict(
             torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
-    @staticmethod
-    def unscale_latents(x):
-        """[0, 1] -> raw latents"""
-        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+class TAESDEncoder(nn.Module):
+    latent_magnitude = 3
+    latent_shift = 0.5
+
+    def __init__(self, encoder_path="taesd_encoder.pth"):
+        """Initialize pretrained TAESD on the given device from the given checkpoints."""
+        super().__init__()
+        self.encoder = encoder()
+        self.encoder.load_state_dict(
+            torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
 
 def download_model(model_path, model_url):
     if not os.path.exists(model_path):
         os.makedirs(os.path.dirname(model_path), exist_ok=True)
 
-        print(f'Downloading TAESD decoder to: {model_path}')
+        print(f'Downloading TAESD model to: {model_path}')
         torch.hub.download_url_to_file(model_url, model_path)
 
 
-def model():
+def decoder_model():
     model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
     loaded_model = sd_vae_taesd_models.get(model_name)
 
@@ -78,7 +95,7 @@ def model():
         download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
 
         if os.path.exists(model_path):
-            loaded_model = TAESD(model_path)
+            loaded_model = TAESDDecoder(model_path)
             loaded_model.eval()
             loaded_model.to(devices.device, devices.dtype)
             sd_vae_taesd_models[model_name] = loaded_model
@@ -86,3 +103,22 @@ def model():
             raise FileNotFoundError('TAESD model not found')
 
     return loaded_model.decoder
+
+
+def encoder_model():
+    model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+    loaded_model = sd_vae_taesd_models.get(model_name)
+
+    if loaded_model is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
+
+        if os.path.exists(model_path):
+            loaded_model = TAESDEncoder(model_path)
+            loaded_model.eval()
+            loaded_model.to(devices.device, devices.dtype)
+            sd_vae_taesd_models[model_name] = loaded_model
+        else:
+            raise FileNotFoundError('TAESD model not found')
+
+    return loaded_model.encoder

+ 2 - 0
modules/shared.py

@@ -430,6 +430,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
+    "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img or inpaint mask)"),
+    "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
 }))
 
 options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {