Преглед на файлове

automatically detect v-parameterization for SD2 checkpoints

AUTOMATIC преди 2 години
родител
ревизия
d04e3e921e
променени са 2 файла, в които са добавени 48 реда и са изтрити 5 реда
  1. 2 0
      modules/sd_hijack.py
  2. 46 5
      modules/sd_models_config.py

+ 2 - 0
modules/sd_hijack.py

@@ -131,6 +131,8 @@ class StableDiffusionModelHijack:
             m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
             m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
             m.cond_stage_model = m.cond_stage_model.wrapped
             m.cond_stage_model = m.cond_stage_model.wrapped
 
 
+        undo_optimizations()
+
         self.apply_circular(False)
         self.apply_circular(False)
         self.layers = None
         self.layers = None
         self.clip = None
         self.clip = None

+ 46 - 5
modules/sd_models_config.py

@@ -1,7 +1,9 @@
 import re
 import re
 import os
 import os
 
 
-from modules import shared, paths
+import torch
+
+from modules import shared, paths, sd_disable_initialization
 
 
 sd_configs_path = shared.sd_configs_path
 sd_configs_path = shared.sd_configs_path
 sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
 sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
@@ -16,12 +18,51 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 
 
-re_parametrization_v = re.compile(r'-v\b')
 
 
+def is_using_v_parameterization_for_sd2(state_dict):
+    """
+    Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
+    """
 
 
-def guess_model_config_from_state_dict(sd, filename):
-    fn = os.path.basename(filename)
+    import ldm.modules.diffusionmodules.openaimodel
+    from modules import devices
+
+    device = devices.cpu
+
+    with sd_disable_initialization.DisableInitialization():
+        unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
+            use_checkpoint=True,
+            use_fp16=False,
+            image_size=32,
+            in_channels=4,
+            out_channels=4,
+            model_channels=320,
+            attention_resolutions=[4, 2, 1],
+            num_res_blocks=2,
+            channel_mult=[1, 2, 4, 4],
+            num_head_channels=64,
+            use_spatial_transformer=True,
+            use_linear_in_transformer=True,
+            transformer_depth=1,
+            context_dim=1024,
+            legacy=False
+        )
+        unet.eval()
+
+    with torch.no_grad():
+        unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
+        unet.load_state_dict(unet_sd, strict=True)
+        unet.to(device=device, dtype=torch.float)
 
 
+        test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
+        x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
+
+        out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
+
+    return out < -1
+
+
+def guess_model_config_from_state_dict(sd, filename):
     sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
     sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
 
 
@@ -31,7 +72,7 @@ def guess_model_config_from_state_dict(sd, filename):
     if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
     if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
         if diffusion_model_input.shape[1] == 9:
         if diffusion_model_input.shape[1] == 9:
             return config_sd2_inpainting
             return config_sd2_inpainting
-        elif re.search(re_parametrization_v, fn):
+        elif is_using_v_parameterization_for_sd2(sd):
             return config_sd2v
             return config_sd2v
         else:
         else:
             return config_sd2
             return config_sd2