@@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             while i < len(tokens):
                 token = tokens[i]

-                embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                 if embedding is None:
                     remade_tokens.append(token)
@@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     remade_tokens += [0] * emb_len
                     multipliers += [weight] * emb_len
                     used_custom_terms.append((embedding.name, embedding.checksum()))
-                    i += emb_len
+                    i += embedding_length_in_tokens

             if len(remade_tokens) > maxlen - 2:
                 vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             while i < len(tokens):
                 token = tokens[i]

-                embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                 mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                 if mult_change is not None:
@@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     remade_tokens += [0] * emb_len
                     multipliers += [mult] * emb_len
                     used_custom_terms.append((embedding.name, embedding.checksum()))
-                    i += emb_len
+                    i += embedding_length_in_tokens

             if len(remade_tokens) > maxlen - 2:
                 vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
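
The point of the change: find_embedding_at_position now returns both the embedding and the number of prompt tokens its trigger name occupies, and the loop advances by that count instead of by emb_len (the embedding's vector count), since the two can differ. A minimal, self-contained sketch of the idea, with hypothetical names rather than the webui implementation:

# Sketch only: illustrates why "length in prompt tokens" and "vector count" must be kept apart.
class FakeEmbedding:
    def __init__(self, name, num_vectors):
        self.name = name
        self.vec_len = num_vectors  # how many vectors it contributes to the conditioning

def find_embedding_at_position(tokens, offset, db):
    # Mimics the lookup: match the trigger's token ids at `offset`,
    # return the embedding plus how many prompt tokens the trigger spans.
    for name_token_ids, emb in db.values():
        if tokens[offset:offset + len(name_token_ids)] == name_token_ids:
            return emb, len(name_token_ids)
    return None, None

# "myart" tokenizes to 2 prompt tokens but injects 4 embedding vectors.
db = {"myart": ([101, 102], FakeEmbedding("myart", num_vectors=4))}
tokens = [7, 101, 102, 9]

i, remade = 0, []
while i < len(tokens):
    emb, length_in_tokens = find_embedding_at_position(tokens, i, db)
    if emb is None:
        remade.append(tokens[i])
        i += 1
    else:
        remade += [0] * emb.vec_len   # pad the output by the embedding's vector count
        i += length_in_tokens         # but advance the prompt by its token count
print(remade)  # [7, 0, 0, 0, 0, 9]

Advancing by emb.vec_len here (the pre-patch behaviour) would skip past or re-read prompt tokens whenever the trigger's token count and the vector count disagree.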