Browse Source

getting SD2.1 to run on SDXL repo

AUTOMATIC1111 2 years ago
parent
commit
af081211ee

+ 3 - 0
modules/launch_utils.py

@@ -235,11 +235,13 @@ def prepare_environment():
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
 
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
 
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -297,6 +299,7 @@ def prepare_environment():
     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
 
     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+    git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
     git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
     git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

+ 1 - 0
modules/paths.py

@@ -20,6 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
 
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),

+ 52 - 12
modules/prompt_parser.py

@@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps):
 
         cond_schedule = []
         for i, (end_at_step, _) in enumerate(prompt_schedule):
-            cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+            if isinstance(conds, dict):
+                cond = {k: v[i] for k, v in conds.items()}
+            else:
+                cond = conds[i]
+
+            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
 
         cache[prompt] = cond_schedule
         res.append(cond_schedule)
@@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
 
 
+class DictWithShape(dict):
+    def __init__(self, x, shape):
+        super().__init__()
+        self.update(x)
+
+    @property
+    def shape(self):
+        return self["crossattn"].shape
+
+
 def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
-    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    is_dict = isinstance(param, dict)
+
+    if is_dict:
+        dict_cond = param
+        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+    else:
+        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
     for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, entry in enumerate(cond_schedule):
             if current_step <= entry.end_at_step:
                 target_index = current
                 break
-        res[i] = cond_schedule[target_index].cond
+
+        if is_dict:
+            for k, param in cond_schedule[target_index].cond.items():
+                res[k][i] = param
+        else:
+            res[i] = cond_schedule[target_index].cond
 
     return res
 
 
+def stack_conds(tensors):
+    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+    # and won't be able to torch.stack them. So this fixes that.
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+    return torch.stack(tensors)
+
+
+
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
     param = c.batch[0][0].schedules[0].cond
 
@@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
 
         conds_list.append(conds_for_batch)
 
-    # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
-    # and won't be able to torch.stack them. So this fixes that.
-    token_count = max([x.shape[0] for x in tensors])
-    for i in range(len(tensors)):
-        if tensors[i].shape[0] != token_count:
-            last_vector = tensors[i][-1:]
-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+    if isinstance(tensors[0], dict):
+        keys = list(tensors[0].keys())
+        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+    else:
+        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
 
-    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+    return conds_list, stacked
 
 
 re_attention = re.compile(r"""

+ 9 - 0
modules/sd_hijack.py

@@ -166,6 +166,15 @@ class StableDiffusionModelHijack:
             undo_optimizations()
 
     def hijack(self, m):
+        conditioner = getattr(m, 'conditioner', None)
+        if conditioner:
+            for i in range(len(conditioner.embedders)):
+                embedder = conditioner.embedders[i]
+                if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+                    conditioner.embedders[i] = m.cond_stage_model
+
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)

+ 4 - 0
modules/sd_hijack_open_clip.py

@@ -16,6 +16,10 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
         self.id_end = tokenizer.encoder["<end_of_text>"]
         self.id_pad = 0
 
+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.legacy_ucg_val = None
+
     def tokenize(self, texts):
         assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
 

+ 6 - 2
modules/sd_models.py

@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
 from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 import tomesd
@@ -289,6 +289,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
+    if hasattr(model, 'conditioner'):
+        sd_models_xl.extend_sdxl(model)
+
     model.load_state_dict(state_dict, strict=False)
     del state_dict
     timer.record("apply weights to model")
@@ -334,7 +337,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.sd_checkpoint_info = checkpoint_info
     shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
 
-    model.logvar = model.logvar.to(devices.device)  # fix for training
+    if hasattr(model, 'logvar'):
+        model.logvar = model.logvar.to(devices.device)  # fix for training
 
     sd_vae.delete_base_vae()
     sd_vae.clear_loaded_vae()

+ 2 - 0
modules/sd_models_config.py

@@ -6,11 +6,13 @@ from modules import shared, paths, sd_disable_initialization
 
 sd_configs_path = shared.sd_configs_path
 sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
 
 
 config_default = shared.sd_default_config
 config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")

+ 40 - 0
modules/sd_models_xl.py

@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import torch
+
+import sgm.models.diffusion
+import sgm.modules.diffusionmodules.denoiser_scaling
+import sgm.modules.diffusionmodules.discretizer
+from modules import devices
+
+
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]):
+    for embedder in self.conditioner.embedders:
+        embedder.ucg_rate = 0.0
+
+    c = self.conditioner({'txt': batch})
+
+    return c
+
+
+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    return self.model(x, t, cond)
+
+
+def extend_sdxl(model):
+    dtype = next(model.model.diffusion_model.parameters()).dtype
+    model.model.diffusion_model.dtype = dtype
+    model.model.conditioning_key = 'crossattn'
+
+    model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0]
+    model.cond_stage_key = model.cond_stage_model.input_key
+
+    model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
+
+    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+
+
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+

+ 35 - 10
modules/sd_samplers_kdiffusion.py

@@ -53,6 +53,28 @@ k_diffusion_scheduler = {
 }
 
 
+def catenate_conds(conds):
+    if not isinstance(conds[0], dict):
+        return torch.cat(conds)
+
+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+    if not isinstance(cond, dict):
+        return cond[a:b]
+
+    return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+    if not isinstance(tensor, dict):
+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+    return tensor
+
+
 class CFGDenoiser(torch.nn.Module):
     """
     Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
@@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module):
 
         if shared.sd_model.model.conditioning_key == "crossattn-adm":
             image_uncond = torch.zeros_like(image_cond)
-            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
         else:
             image_uncond = image_cond
-            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+            if isinstance(uncond, dict):
+                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+            else:
+                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
 
         if not is_edit_model:
             x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module):
             num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
 
             if num_repeats < 0:
-                tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+                tensor = pad_cond(tensor, -num_repeats, empty)
                 self.padded_cond_uncond = True
             elif num_repeats > 0:
-                uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+                uncond = pad_cond(uncond, num_repeats, empty)
                 self.padded_cond_uncond = True
 
         if tensor.shape[1] == uncond.shape[1] or skip_uncond:
             if is_edit_model:
-                cond_in = torch.cat([tensor, uncond, uncond])
+                cond_in = catenate_conds([tensor, uncond, uncond])
             elif skip_uncond:
                 cond_in = tensor
             else:
-                cond_in = torch.cat([tensor, uncond])
+                cond_in = catenate_conds([tensor, uncond])
 
             if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
             else:
                 x_out = torch.zeros_like(x_in)
                 for batch_offset in range(0, x_out.shape[0], batch_size):
                     a = batch_offset
                     b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b]))
         else:
             x_out = torch.zeros_like(x_in)
             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module):
                 b = min(a + batch_size, tensor.shape[0])
 
                 if not is_edit_model:
-                    c_crossattn = [tensor[a:b]]
+                    c_crossattn = subscript_cond(tensor, a, b)
                 else:
                     c_crossattn = torch.cat([tensor[a:b]], uncond)
 
                 x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
 
             if not skip_uncond:
-                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
 
         denoised_image_indexes = [x[0][0] for x in conds_list]
         if skip_uncond: