
textual inversion support for SDXL

AUTOMATIC1111 committed 2 years ago
commit 6f0abbb71a

4 files changed, 29 insertions and 9 deletions
  1. modules/sd_hijack.py (+5 / -3)
  2. modules/sd_hijack_clip.py (+1 / -1)
  3. modules/sd_models_xl.py (+9 / -0)
  4. modules/textual_inversion/textual_inversion.py (+14 / -5)

modules/sd_hijack.py (+5 / -3)

@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
                     conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                     text_cond_models.append(conditioner.embedders[i])
                 if typename == 'FrozenOpenCLIPEmbedder2':
-                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                     conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                     text_cond_models.append(conditioner.embedders[i])

@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:


 class EmbeddingsWithFixes(torch.nn.Module):
-    def __init__(self, wrapped, embeddings):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
         super().__init__()
         self.wrapped = wrapped
         self.embeddings = embeddings
+        self.textual_inversion_key = textual_inversion_key

     def forward(self, input_ids):
         batch_fixes = self.embeddings.fixes

@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
         vecs = []
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
-                emb = devices.cond_cast_unet(embedding.vec)
+                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb = devices.cond_cast_unet(vec)
                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])


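SDXL has two text encoders, so EmbeddingsWithFixes is now created twice: with textual_inversion_key='clip_l' for the CLIP ViT-L branch (the default) and 'clip_g' for the OpenCLIP ViT-bigG branch. When an embedding's vec is a dict, each wrapper picks the tensor that belongs to its own encoder; a plain tensor (SD1/SD2 embedding) is used unchanged. A minimal sketch of just that selection logic, outside the webui code and with an invented vector count (the 768/1280 widths are the hidden sizes of the two SDXL text encoders):

import torch

def pick_vec(vec, textual_inversion_key):
    # mirrors the isinstance check added to EmbeddingsWithFixes.forward:
    # SDXL embeddings carry one tensor per text encoder in a dict,
    # older embeddings are a single tensor for the lone encoder
    return vec[textual_inversion_key] if isinstance(vec, dict) else vec

# hypothetical SDXL embedding with 2 vectors per encoder
sdxl_vec = {
    "clip_l": torch.zeros(2, 768),   # CLIP ViT-L branch
    "clip_g": torch.zeros(2, 1280),  # OpenCLIP ViT-bigG branch
}

print(pick_vec(sdxl_vec, "clip_l").shape)            # torch.Size([2, 768])
print(pick_vec(sdxl_vec, "clip_g").shape)            # torch.Size([2, 1280])
print(pick_vec(torch.zeros(2, 768), "clip_g").shape) # plain SD1-style tensor, returned as-is
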
modules/sd_hijack_clip.py (+1 / -1)

@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                     position += 1
                     continue

-                emb_len = int(embedding.vec.shape[0])
+                emb_len = int(embedding.vectors)
                 if len(chunk.tokens) + emb_len > self.chunk_length:
                     next_chunk()


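The prompt-chunking code can no longer read the token count from embedding.vec.shape[0], because for SDXL vec is a dict of two tensors with no .shape; it uses the precomputed embedding.vectors attribute instead, which the loader now fills in for every embedding type (see textual_inversion.py below). A rough illustration with a hypothetical stand-in object:

import torch

class FakeEmbedding:
    # stand-in for modules.textual_inversion.textual_inversion.Embedding
    def __init__(self, vec, vectors):
        self.vec = vec          # Tensor for SD1/SD2, dict of tensors for SDXL
        self.vectors = vectors  # number of prompt tokens the embedding occupies

sd1 = FakeEmbedding(torch.zeros(3, 768), vectors=3)
sdxl = FakeEmbedding({"clip_l": torch.zeros(3, 768), "clip_g": torch.zeros(3, 1280)}, vectors=3)

# old code, int(embedding.vec.shape[0]), would raise AttributeError for the dict case;
# reading .vectors works for both:
for embedding in (sd1, sdxl):
    emb_len = int(embedding.vectors)
    print(emb_len)  # 3
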
modules/sd_models_xl.py (+9 / -0)

@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
     return torch.cat(res, dim=1)


+def tokenize(self: sgm.modules.GeneralConditioner, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
+        return embedder.tokenize(texts)
+
+    raise AssertionError('no tokenizer available')
+
+
+
 def process_texts(self, texts):
     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
         return embedder.process_texts(texts)

@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):

 # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
 sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.tokenize = tokenize
 sgm.modules.GeneralConditioner.process_texts = process_texts
 sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count


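The new tokenize helper follows the same delegation pattern as process_texts just below it: GeneralConditioner itself has no tokenizer, so the call is forwarded to the first embedder that exposes one, and an AssertionError is raised if none does. The pattern in isolation, with dummy embedders standing in for the real SGM classes:

class DummyCLIPEmbedder:
    # stand-in for a text embedder that wraps a tokenizer
    def tokenize(self, texts):
        return [[101, 102] for _ in texts]  # fake token ids

class DummyVectorEmbedder:
    # stand-in for a non-text conditioner (e.g. size/crop), has no tokenize()
    pass

class FakeConditioner:
    def __init__(self, embedders):
        self.embedders = embedders

    # same shape as the tokenize() patched onto sgm.modules.GeneralConditioner
    def tokenize(self, texts):
        for embedder in [e for e in self.embedders if hasattr(e, 'tokenize')]:
            return embedder.tokenize(texts)

        raise AssertionError('no tokenizer available')

conditioner = FakeConditioner([DummyVectorEmbedder(), DummyCLIPEmbedder()])
print(conditioner.tokenize(["a photo of a cat"]))  # [[101, 102]]
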
modules/textual_inversion/textual_inversion.py (+14 / -5)

@@ -181,29 +181,38 @@ class EmbeddingDatabase:
         else:
             return

+
         # textual inversion embeddings
         if 'string_to_param' in data:
             param_dict = data['string_to_param']
             param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
             emb = next(iter(param_dict.items()))[1]
-        # diffuser concepts
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
+        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
             assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

             emb = next(iter(data.values()))
             if len(emb.shape) == 1:
                 emb = emb.unsqueeze(0)
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
         else:
             raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

-        vec = emb.detach().to(devices.device, dtype=torch.float32)
         embedding = Embedding(vec, name)
         embedding.step = data.get('step', None)
         embedding.sd_checkpoint = data.get('sd_checkpoint', None)
         embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
+        embedding.vectors = vectors
         embedding.shape = shape
         embedding.filename = path
         embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
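
An SDXL textual inversion file is recognized simply by containing both a 'clip_g' and a 'clip_l' tensor: vectors is taken from the first dimension of the clip_g tensor and shape is the sum of the two hidden sizes. A sketch of how such a dict would be sized up, using an invented vector count (768 and 1280 are the SDXL text encoder widths; devices.device handling from the webui is omitted):

import torch

# hypothetical SDXL embedding with 4 vectors per text encoder
data = {
    "clip_l": torch.randn(4, 768),   # CLIP ViT-L branch
    "clip_g": torch.randn(4, 1280),  # OpenCLIP ViT-bigG branch
}

# mirrors the new loading branch in EmbeddingDatabase
if isinstance(data, dict) and 'clip_g' in data and 'clip_l' in data:
    vec = {k: v.detach().to(dtype=torch.float32) for k, v in data.items()}
    shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]  # 1280 + 768 = 2048
    vectors = data['clip_g'].shape[0]                            # 4

print(vectors, shape)  # 4 2048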