瀏覽代碼

remove the need to place configs near models

AUTOMATIC 2 年之前
父節點
當前提交
d2ac95fa7b

+ 99 - 0
configs/instruct-pix2pix.yaml

@@ -0,0 +1,99 @@
+# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+# See more details in LICENSE.
+
+model:
+  base_learning_rate: 1.0e-04
+  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: edited
+    cond_stage_key: edit
+    # image_size: 64
+    # image_size: 32
+    image_size: 16
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: true
+    load_ema: true
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 0 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 128
+    num_workers: 1
+    wrap: false
+    validation:
+      target: edit_dataset.EditDataset
+      params:
+        path: data/clip-filtered-dataset
+        cache_dir:  data/
+        cache_name: data_10k
+        split: val
+        min_text_sim: 0.2
+        min_image_sim: 0.75
+        min_direction_sim: 0.2
+        max_samples_per_prompt: 1
+        min_resize_res: 512
+        max_resize_res: 512
+        crop_res: 512
+        output_as_edit: False
+        real_input: True

+ 19 - 17
v2-inference-v.yaml → configs/v1-inpainting-inference.yaml

@@ -1,8 +1,7 @@
 model:
-  base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
   params:
-    parameterization: "v"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -12,29 +11,36 @@ model:
     cond_stage_key: "txt"
     image_size: 64
     channels: 4
-    cond_stage_trainable: false
-    conditioning_key: crossattn
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid   # important
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: False # we set this to false because this is an inference only config
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
 
     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
       params:
-        use_checkpoint: True
-        use_fp16: True
         image_size: 32 # unused
-        in_channels: 4
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
         out_channels: 4
         model_channels: 320
         attention_resolutions: [ 4, 2, 1 ]
         num_res_blocks: 2
         channel_mult: [ 1, 2, 4, 4 ]
-        num_head_channels: 64 # need to fix for flash-attn
+        num_heads: 8
         use_spatial_transformer: True
-        use_linear_in_transformer: True
         transformer_depth: 1
-        context_dim: 1024
+        context_dim: 768
+        use_checkpoint: True
         legacy: False
 
     first_stage_config:
@@ -43,7 +49,6 @@ model:
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
-          #attn_type: "vanilla-xformers"
           double_z: true
           z_channels: 4
           resolution: 256
@@ -62,7 +67,4 @@ model:
           target: torch.nn.Identity
 
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
-        freeze: True
-        layer: "penultimate"
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

+ 3 - 2
modules/api/api.py

@@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
@@ -387,7 +388,7 @@ class Api:
         ]
 
     def get_sd_models(self):
-        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
 
     def get_hypernetworks(self):
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

+ 8 - 4
modules/devices.py

@@ -34,14 +34,18 @@ def get_cuda_device_string():
     return "cuda"
 
 
-def get_optimal_device():
+def get_optimal_device_name():
     if torch.cuda.is_available():
-        return torch.device(get_cuda_device_string())
+        return get_cuda_device_string()
 
     if has_mps():
-        return torch.device("mps")
+        return "mps"
+
+    return "cpu"
 
-    return cpu
+
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())
 
 
 def get_device_for(task):

+ 0 - 9
modules/sd_hijack_inpainting.py

@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
     return x_prev, pred_x0, e_t
 
 
-def should_hijack_inpainting(checkpoint_info):
-    from modules import sd_models
-
-    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
-    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
-
-    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
-
-
 def do_inpainting_hijack():
     # p_sample_plms is needed because PLMS can't work with dicts as conditionings
 

+ 113 - 115
modules/sd_models.py

@@ -2,8 +2,6 @@ import collections
 import os.path
 import sys
 import gc
-import time
-from collections import namedtuple
 import torch
 import re
 import safetensors.torch
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -99,17 +97,6 @@ def checkpoint_tiles():
     return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
 
 
-def find_checkpoint_config(info):
-    if info is None:
-        return shared.cmd_opts.config
-
-    config = os.path.splitext(info.filename)[0] + ".yaml"
-    if os.path.exists(config):
-        return config
-
-    return shared.cmd_opts.config
-
-
 def list_models():
     checkpoints_list.clear()
     checkpoint_alisases.clear()
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
-        device = map_location or shared.weight_load_location
-        if device is None:
-            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
         pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
     return sd
 
 
-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+    sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
+    if checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        return checkpoints_loaded[checkpoint_info]
+
+    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    res = read_state_dict(checkpoint_info.filename)
+    timer.record("load weights from disk")
+
+    return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     title = checkpoint_info.title
     sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
     if checkpoint_info.title != title:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
-    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+    if state_dict is None:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    if cache_enabled and checkpoint_info in checkpoints_loaded:
-        # use checkpoint cache
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-    else:
-        # load from file
-        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    model.load_state_dict(state_dict, strict=False)
+    del state_dict
+    timer.record("apply weights to model")
 
-        sd = read_state_dict(checkpoint_info.filename)
-        model.load_state_dict(sd, strict=False)
-        del sd
-        
-        if cache_enabled:
-            # cache newly loaded model
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+    if shared.opts.sd_checkpoint_cache > 0:
+        # cache newly loaded model
+        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+
+    if shared.cmd_opts.opt_channelslast:
+        model.to(memory_format=torch.channels_last)
+        timer.record("apply channels_last")
 
-        if shared.cmd_opts.opt_channelslast:
-            model.to(memory_format=torch.channels_last)
+    if not shared.cmd_opts.no_half:
+        vae = model.first_stage_model
+        depth_model = getattr(model, 'depth_model', None)
 
-        if not shared.cmd_opts.no_half:
-            vae = model.first_stage_model
-            depth_model = getattr(model, 'depth_model', None)
+        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+        if shared.cmd_opts.no_half_vae:
+            model.first_stage_model = None
+        # with --upcast-sampling, don't convert the depth model weights to float16
+        if shared.cmd_opts.upcast_sampling and depth_model:
+            model.depth_model = None
 
-            # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-            if shared.cmd_opts.no_half_vae:
-                model.first_stage_model = None
-            # with --upcast-sampling, don't convert the depth model weights to float16
-            if shared.cmd_opts.upcast_sampling and depth_model:
-                model.depth_model = None
+        model.half()
+        model.first_stage_model = vae
+        if depth_model:
+            model.depth_model = depth_model
 
-            model.half()
-            model.first_stage_model = vae
-            if depth_model:
-                model.depth_model = depth_model
+        timer.record("apply half()")
 
-        devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
-        devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
-        devices.dtype_unet = model.model.diffusion_model.dtype
-        devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+    devices.dtype_unet = model.model.diffusion_model.dtype
+    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
-        model.first_stage_model.to(devices.dtype_vae)
+    model.first_stage_model.to(devices.dtype_vae)
+    timer.record("apply dtype to VAE")
 
     # clean up cache if limit is reached
-    if cache_enabled:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
-            checkpoints_loaded.popitem(last=False)  # LRU
+    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+        checkpoints_loaded.popitem(last=False)
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_info.filename
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
     sd_vae.clear_loaded_vae()
     vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
     sd_vae.load_vae(model, vae_file, vae_source)
+    timer.record("load VAE")
 
 
 def enable_midas_autodownload():
@@ -340,24 +340,20 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper
 
 
-class Timer:
-    def __init__(self):
-        self.start = time.time()
+def repair_config(sd_config):
 
-    def elapsed(self):
-        end = time.time()
-        res = end - self.start
-        self.start = end
-        return res
+    if not hasattr(sd_config.model.params, "use_ema"):
+        sd_config.model.params.use_ema = False
 
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = True
 
-def load_model(checkpoint_info=None):
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
-    checkpoint_config = find_checkpoint_config(checkpoint_info)
-
-    if checkpoint_config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_config}")
 
     if shared.sd_model:
         sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
         gc.collect()
         devices.torch_gc()
 
-    sd_config = OmegaConf.load(checkpoint_config)
-    
-    if should_hijack_inpainting(checkpoint_info):
-        # Hardcoded config for now...
-        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.unet_config.params.in_channels = 9
-        sd_config.model.params.finetune_keys = None
-
-    if should_hijack_ip2p(checkpoint_info):
-        sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.first_stage_key = "edited"
-        sd_config.model.params.cond_stage_key = "edit"
-        sd_config.model.params.image_size = 16
-        sd_config.model.params.unet_config.params.in_channels = 8
-        sd_config.model.params.unet_config.params.out_channels = 4
+    do_inpainting_hijack()
 
-    if not hasattr(sd_config.model.params, "use_ema"):
-        sd_config.model.params.use_ema = False
+    timer = Timer()
 
-    do_inpainting_hijack()
+    if already_loaded_state_dict is not None:
+        state_dict = already_loaded_state_dict
+    else:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
 
-    timer = Timer()
+    timer.record("find config")
 
-    sd_model = None
+    sd_config = OmegaConf.load(checkpoint_config)
+    repair_config(sd_config)
+
+    timer.record("load config")
+
+    print(f"Creating model from config: {checkpoint_config}")
 
+    sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
@@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         sd_model = instantiate_from_config(sd_config.model)
 
-    elapsed_create = timer.elapsed()
+    sd_model.used_config = checkpoint_config
 
-    load_model_weights(sd_model, checkpoint_info)
+    timer.record("create model")
 
-    elapsed_load_weights = timer.elapsed()
+    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
         sd_model.to(shared.device)
 
+    timer.record("move model to device")
+
     sd_hijack.model_hijack.hijack(sd_model)
 
+    timer.record("hijack")
+
     sd_model.eval()
     shared.sd_model = sd_model
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
 
+    timer.record("load textual inversion embeddings")
+
     script_callbacks.model_loaded_callback(sd_model)
 
-    elapsed_the_rest = timer.elapsed()
+    timer.record("scripts callbacks")
 
-    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
+    print(f"Model loaded in {timer.summary()}.")
 
     return sd_model
 
@@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
 
     if not sd_model:
         sd_model = shared.sd_model
+
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
     else:
@@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return
 
-    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
-
-    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
-        del sd_model
-        checkpoints_loaded.clear()
-        load_model(checkpoint_info)
-        return shared.sd_model
-
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.send_everything_to_cpu()
     else:
@@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
 
     timer = Timer()
 
+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    if sd_model is None or checkpoint_config != sd_model.used_config:
+        del sd_model
+        checkpoints_loaded.clear()
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+        return shared.sd_model
+
     try:
-        load_model_weights(sd_model, checkpoint_info)
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     except Exception as e:
         print("Failed to load checkpoint, restoring previous")
-        load_model_weights(sd_model, current_checkpoint_info)
+        load_model_weights(sd_model, current_checkpoint_info, None, timer)
         raise
     finally:
         sd_hijack.model_hijack.hijack(sd_model)
+        timer.record("hijack")
+
         script_callbacks.model_loaded_callback(sd_model)
+        timer.record("script callbacks")
 
         if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
             sd_model.to(devices.device)
+            timer.record("move model to device")
 
-    elapsed = timer.elapsed()
-
-    print(f"Weights loaded in {elapsed:.1f}s.")
+    print(f"Weights loaded in {timer.summary()}.")
 
     return sd_model

+ 65 - 0
modules/sd_models_config.py

@@ -0,0 +1,65 @@
+import re
+import os
+
+from modules import shared, paths
+
+sd_configs_path = shared.sd_configs_path
+sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+
+
+config_default = shared.sd_default_config
+config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
+config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
+config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
+config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
+
+re_parametrization_v = re.compile(r'-v\b')
+
+
+def guess_model_config_from_state_dict(sd, filename):
+    fn = os.path.basename(filename)
+
+    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
+    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
+    roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)
+
+    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
+        if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
+            return config_sd2v
+        else:
+            return config_sd2
+
+    if diffusion_model_input is not None:
+        if diffusion_model_input.shape[1] == 9:
+            return config_inpainting
+        if diffusion_model_input.shape[1] == 8:
+            return config_instruct_pix2pix
+
+    if roberta_weight is not None:
+        return config_alt_diffusion
+
+    return config_default
+
+
+def find_checkpoint_config(state_dict, info):
+    if info is None:
+        return guess_model_config_from_state_dict(state_dict, "")
+
+    config = find_checkpoint_config_near_filename(info)
+    if config is not None:
+        return config
+
+    return guess_model_config_from_state_dict(state_dict, info.filename)
+
+
+def find_checkpoint_config_near_filename(info):
+    if info is None:
+        return None
+
+    config = os.path.splitext(info.filename)[0] + ".yaml"
+    if os.path.exists(config):
+        return config
+
+    return None
+

+ 4 - 3
modules/shared.py

@@ -13,13 +13,14 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
+from modules import localization, extensions, script_loading, errors, ui_components, shared_items
 from modules.paths import models_path, script_path
 
 
 demo = None
 
-sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
 
@@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),

+ 14 - 1
modules/shared_items.py

@@ -4,7 +4,20 @@ def realesrgan_models_names():
     import modules.realesrgan_model
     return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
 
+
 def postprocessing_scripts():
     import modules.scripts
 
-    return modules.scripts.scripts_postproc.scripts
+    return modules.scripts.scripts_postproc.scripts
+
+
+def sd_vae_items():
+    import modules.sd_vae
+
+    return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
+
+
+def refresh_vae_list():
+    import modules.sd_vae
+
+    return modules.sd_vae.refresh_vae_list

+ 35 - 0
modules/timer.py

@@ -0,0 +1,35 @@
+import time
+
+
+class Timer:
+    def __init__(self):
+        self.start = time.time()
+        self.records = {}
+        self.total = 0
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+    def record(self, category, extra_time=0):
+        e = self.elapsed()
+        if category not in self.records:
+            self.records[category] = 0
+
+        self.records[category] += e + extra_time
+        self.total += e + extra_time
+
+    def summary(self):
+        res = f"{self.total:.1f}s"
+
+        additions = [x for x in self.records.items() if x[1] >= 0.1]
+        if not additions:
+            return res
+
+        res += " ("
+        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
+        res += ")"
+
+        return res