
fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!)
add option to input initialization text for embeddings

AUTOMATIC, 2 years ago
commit 88ec0cf557
4 changed files with 13 additions and 14 deletions
  1. modules/sd_hijack.py (+4 -4)
  2. modules/textual_inversion/textual_inversion.py (+5 -8)
  3. modules/textual_inversion/ui.py (+2 -2)
  4. modules/ui.py (+2 -0)

+ 4 - 4
modules/sd_hijack.py

@@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             while i < len(tokens):
                 token = tokens[i]
 
-                embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
                 if embedding is None:
                     remade_tokens.append(token)
@@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     remade_tokens += [0] * emb_len
                     multipliers += [weight] * emb_len
                     used_custom_terms.append((embedding.name, embedding.checksum()))
-                    i += emb_len
+                    i += embedding_length_in_tokens
 
         if len(remade_tokens) > maxlen - 2:
             vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 while i < len(tokens):
                     token = tokens[i]
 
-                    embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
                     mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                     if mult_change is not None:
@@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                         remade_tokens += [0] * emb_len
                         multipliers += [mult] * emb_len
                         used_custom_terms.append((embedding.name, embedding.checksum()))
-                        i += emb_len
+                        i += embedding_length_in_tokens
 
                 if len(remade_tokens) > maxlen - 2:
                     vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
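
The loop-index change above matters because an embedding's vector count (emb_len, the number of rows in its learned tensor) and the number of prompt tokens its trigger name occupies can differ. A minimal sketch with hypothetical token ids and counts:

```python
# Hypothetical illustration of the indexing bug fixed above: an embedding whose
# trigger word tokenizes into 3 prompt tokens but expands to 8 learned vectors.
prompt_token_ids = [320, 1125, 267, 2368, 530]   # hypothetical CLIP token ids
matched_ids = [320, 1125, 267]                   # ids matched to the embedding name
num_vectors = 8                                  # vectors stored in the embedding

emb_len = num_vectors                            # what the old code advanced i by
embedding_length_in_tokens = len(matched_ids)    # what the fixed code advances i by

# Old behaviour: i += 8 skipped the prompt tokens that followed the embedding
# name, so the remade prompt (and therefore the image for a given seed) changed
# whenever the two counts differed.
# New behaviour: i += 3 resumes right after the embedding name.
```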

+ 5 - 8
modules/textual_inversion/textual_inversion.py

@@ -117,24 +117,21 @@ class EmbeddingDatabase:
         possible_matches = self.ids_lookup.get(token, None)
 
         if possible_matches is None:
-            return None
+            return None, None
 
         for ids, embedding in possible_matches:
             if tokens[offset:offset + len(ids)] == ids:
-                return embedding
+                return embedding, len(ids)
 
-        return None
+        return None, None
 
 
-
-def create_embedding(name, num_vectors_per_token):
-    init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
 
     ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
-    embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+    embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 
     for i in range(num_vectors_per_token):
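
For context, here is a rough standalone sketch of what initialization from init_text does, assuming a plain Hugging Face CLIP model rather than the webui's wrapped cond_stage_model; the vector-copy step is a simplification, since the body of the loop is not shown in the hunk above:

```python
# Standalone sketch (not the webui code) of seeding a new textual-inversion
# embedding from an initialization text, using a plain Hugging Face CLIP model.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

def init_embedding_vectors(init_text: str, num_vectors: int) -> torch.Tensor:
    # Tokenize without BOS/EOS so only the real words of the init text count,
    # capped at the requested number of vectors.
    ids = tokenizer(init_text, max_length=num_vectors, truncation=True,
                    return_tensors="pt", add_special_tokens=False)["input_ids"]
    with torch.no_grad():
        # Look up the frozen token-embedding table only (not the full text
        # encoder), mirroring the switch to token_embedding in the diff above.
        embedded = text_model.text_model.embeddings.token_embedding(ids).squeeze(0)
    vec = torch.zeros((num_vectors, embedded.shape[1]))
    for i in range(num_vectors):
        # Simplification: reuse the looked-up vectors cyclically when the init
        # text has fewer tokens than the requested vector count.
        vec[i] = embedded[i % embedded.shape[0]]
    return vec

# e.g. init_embedding_vectors("a photo of a cat", num_vectors=4) -> tensor of shape (4, 768)
```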

+ 2 - 2
modules/textual_inversion/ui.py

@@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
 from modules import sd_hijack, shared
 
 
-def create_embedding(name, nvpt):
-    filename = ti.create_embedding(name, nvpt)
+def create_embedding(name, initialization_text, nvpt):
+    filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
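
The wrapper keeps the gradio argument order (name, initialization text, vector count) but forwards init_text as a keyword, so ti.create_embedding's positional signature and its default of '*' stay intact. A hypothetical direct call, equivalent to what the wrapper does:

```python
# Hypothetical direct call mirroring the wrapper above; requires the webui
# modules to be importable. The name and init text here are made up.
import modules.textual_inversion.textual_inversion as ti

filename = ti.create_embedding("my-style", 4, init_text="a photo in the style of")
print(filename)  # the wrapper stores this return value as `filename`
```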
 

+ 2 - 0
modules/ui.py

@@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
 
                     new_embedding_name = gr.Textbox(label="Name")
+                    initialization_text = gr.Textbox(label="Initialization text", value="*")
                     nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
 
                     with gr.Row():
@@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
             fn=modules.textual_inversion.ui.create_embedding,
             inputs=[
                 new_embedding_name,
+                initialization_text,
                 nvpt,
             ],
             outputs=[
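
Gradio passes the components listed in inputs to the callback positionally, which is why the new textbox is added between the name field and the slider both in the layout and in the inputs list. A minimal self-contained sketch of that wiring, with a hypothetical stub handler:

```python
# Minimal gradio sketch (hypothetical stub handler) of the wiring above: the
# `inputs` list feeds the callback positionally, so the new textbox must sit
# between the name field and the vector-count slider in both places.
import gradio as gr

def create_embedding(name, initialization_text, nvpt):
    return f"would create '{name}' from '{initialization_text}' with {int(nvpt)} vector(s) per token"

with gr.Blocks() as demo:
    new_embedding_name = gr.Textbox(label="Name")
    initialization_text = gr.Textbox(label="Initialization text", value="*")
    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
    result = gr.Textbox(label="Result")
    gr.Button("Create embedding").click(
        fn=create_embedding,
        inputs=[new_embedding_name, initialization_text, nvpt],
        outputs=[result],
    )

# demo.launch()
```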