
Support SDXL v-pred models

catboxanon committed 10 months ago
commit 1ae073c052
3 changed files with 106 additions and 3 deletions
  1. configs/sd_xl_v.yaml (+98, -0)
  2. modules/sd_models.py (+4, -3)
  3. modules/sd_models_config.py (+4, -0)
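
For context (not part of the diff): a v-prediction ("v-pred") model trains the UNet to predict the velocity v instead of the noise eps, following Salimans & Ho, "Progressive Distillation for Fast Sampling of Diffusion Models". A minimal sketch under the variance-preserving schedule, where alpha_t**2 + sigma_t**2 == 1 (the function name is illustrative, not from this commit):

    def v_target(x0, eps, alpha_t, sigma_t):
        # Training target for a v-pred UNet. Given x_t = alpha_t * x0 + sigma_t * eps,
        # x0 is recovered from a v prediction as x0 = alpha_t * x_t - sigma_t * v.
        return alpha_t * eps - sigma_t * x0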

+ 98 - 0
configs/sd_xl_v.yaml

@@ -0,0 +1,98 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        weighting_config:
+          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        adm_in_channels: 2816
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4]
+        num_head_channels: 64
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+        context_dim: 2048
+        spatial_transformer_attn_type: softmax-xformers
+        legacy: False
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          # crossattn cond
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+            params:
+              layer: hidden
+              layer_idx: 11
+          # crossattn and vector cond
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+            params:
+              arch: ViT-bigG-14
+              version: laion2b_s39b_b160k
+              freeze: True
+              layer: penultimate
+              always_return_pooled: True
+              legacy: False
+          # vector cond
+          - is_trainable: False
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+          # vector cond
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+          # vector cond
+          - is_trainable: False
+            input_key: target_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
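
This config tracks the upstream sgm sd_xl_base.yaml; the functional difference is the scaling_config, which selects VScaling instead of EpsScaling so samplers interpret the UNet output as v rather than eps. For reference, a sketch of the per-sigma coefficients VScaling yields (paraphrased from sgm's denoiser_scaling.py; treat the exact class shape as an assumption):

    import torch

    class VScaling:
        def __call__(self, sigma: torch.Tensor):
            # denoised = c_skip * noised + c_out * model(c_in * noised, c_noise)
            c_skip = 1.0 / (sigma**2 + 1.0)
            c_out = -sigma / (sigma**2 + 1.0) ** 0.5
            c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
            c_noise = sigma.clone()
            return c_skip, c_out, c_in, c_noise

As in the base config, adm_in_channels is 2816: the 1280-dim pooled ViT-bigG embedding plus three 2-tuple size/crop conditions embedded at 256 dims per component (1280 + 6 * 256 = 2816).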

+ 4 - 3
modules/sd_models.py

@@ -783,7 +783,7 @@ def get_obj_from_str(string, reload=False):
     return getattr(importlib.import_module(module, package=None), cls)
 
 
-def load_model(checkpoint_info=None, already_loaded_state_dict=None):
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_config=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
@@ -801,7 +801,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     else:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+    if not checkpoint_config:
+        checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
     clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
 
     timer.record("find config")
@@ -974,7 +975,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
         if sd_model is not None:
             send_model_to_trash(sd_model)
 
-        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, checkpoint_config=checkpoint_config)
         return model_data.sd_model
 
     try:
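
The new checkpoint_config parameter lets reload_model_weights() hand the config it already resolved straight to load_model(), instead of re-running detection on the state dict. That matters for v-pred checkpoints: as the sd_models_config.py hunk below shows, detection deletes the 'v_pred' marker key, so a second pass over the same dict would fall back to the base SDXL config. A hypothetical call sketch (the filename is an assumption):

    from modules import sd_models, sd_models_config

    info = sd_models.get_closet_checkpoint_match("my_vpred_finetune.safetensors")
    sd_models.load_model(info, checkpoint_config=sd_models_config.config_sdxlv)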

+ 4 - 0
modules/sd_models_config.py

@@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxlv = os.path.join(sd_configs_path, "sd_xl_v.yaml")
 config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
 config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
@@ -81,6 +82,9 @@ def guess_model_config_from_state_dict(sd, filename):
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
         else:
+            if 'v_pred' in sd:
+                del sd['v_pred']
+                return config_sdxlv
             return config_sdxl
 
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
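
Detection hinges on a bare 'v_pred' key in the state dict, the marker convention shipped by v-pred SDXL finetunes; the key is deleted before loading so the unexpected entry cannot break state-dict loading. A minimal sketch of tagging a checkpoint this way (paths are assumptions; only the key's presence matters, not its value):

    import torch
    from safetensors.torch import load_file, save_file

    sd = load_file("sdxl_finetune.safetensors")
    sd["v_pred"] = torch.tensor([])  # empty marker tensor; presence alone selects sd_xl_v.yaml
    save_file(sd, "sdxl_finetune_vpred.safetensors")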