AUTOMATIC 2 жил өмнө
parent
commit
79e39fae61

+ 3 - 3
modules/sd_hijack.py

@@ -150,10 +150,10 @@ class StableDiffusionModelHijack:
     def clear_comments(self):
         self.comments = []
 
-    def tokenize(self, text):
-        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+    def get_prompt_lengths(self, text):
+        _, token_count = self.clip.process_texts([text])
 
-        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+        return token_count, self.clip.get_target_prompt_token_count(token_count)
 
 
 class EmbeddingsWithFixes(torch.nn.Module):

+ 171 - 177
modules/sd_hijack_clip.py

@@ -1,12 +1,28 @@
 import math
+from collections import namedtuple
 
 import torch
 
 from modules import prompt_parser, devices
 from modules.shared import opts
 
-def get_target_prompt_token_count(token_count):
-    return math.ceil(max(token_count, 1) / 75) * 75
+
+class PromptChunk:
+    """
+    This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
+    If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
+    Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
+    so just 75 tokens from prompt.
+    """
+
+    def __init__(self):
+        self.tokens = []
+        self.multipliers = []
+        self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk"""
 
 
 class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
@@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         super().__init__()
         self.wrapped = wrapped
         self.hijack = hijack
+        self.chunk_length = 75
+
+    def empty_chunk(self):
+        """creates an empty PromptChunk and returns it"""
+
+        chunk = PromptChunk()
+        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+        chunk.multipliers = [1.0] * (self.chunk_length + 2)
+        return chunk
+
+    def get_target_prompt_token_count(self, token_count):
+        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
 
     def tokenize(self, texts):
+        """Converts a batch of texts into a batch of token ids"""
+
         raise NotImplementedError
 
     def encode_with_transformers(self, tokens):
+        """
+        converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
+        All python lists with tokens are assumed to have same length, usually 77.
+        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
+        model - can be 768 and 1024
+        """
+
         raise NotImplementedError
 
     def encode_embedding_init_text(self, init_text, nvpt):
+        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+        transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
+
         raise NotImplementedError
 
-    def tokenize_line(self, line, used_custom_terms, hijack_comments):
+    def tokenize_line(self, line):
+        """
+        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+        represent the prompt.
+        Returns the list and the total number of tokens in the prompt.
+        """
+
         if opts.enable_emphasis:
             parsed = prompt_parser.parse_prompt_attention(line)
         else:
@@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
 
         tokenized = self.tokenize([text for text, _ in parsed])
 
-        fixes = []
-        remade_tokens = []
-        multipliers = []
+        chunks = []
+        chunk = PromptChunk()
+        token_count = 0
         last_comma = -1
 
-        for tokens, (text, weight) in zip(tokenized, parsed):
-            i = 0
-            while i < len(tokens):
-                token = tokens[i]
+        def next_chunk():
+            """puts current chunk into the list of results and produces the next one - empty"""
+            nonlocal token_count
+            nonlocal last_comma
+            nonlocal chunk
+
+            token_count += len(chunk.tokens)
+            to_add = self.chunk_length - len(chunk.tokens)
+            if to_add > 0:
+                chunk.tokens += [self.id_end] * to_add
+                chunk.multipliers += [1.0] * to_add
 
-                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+            last_comma = -1
+            chunks.append(chunk)
+            chunk = PromptChunk()
+
+        for tokens, (text, weight) in zip(tokenized, parsed):
+            position = 0
+            while position < len(tokens):
+                token = tokens[position]
 
                 if token == self.comma_token:
-                    last_comma = len(remade_tokens)
-                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
-                    last_comma += 1
-                    reloc_tokens = remade_tokens[last_comma:]
-                    reloc_mults = multipliers[last_comma:]
+                    last_comma = len(chunk.tokens)
+
+                # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+                # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next.
+                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+                    break_location = last_comma + 1
+
+                    reloc_tokens = chunk.tokens[break_location:]
+                    reloc_mults = chunk.multipliers[break_location:]
 
-                    remade_tokens = remade_tokens[:last_comma]
-                    length = len(remade_tokens)
+                    chunk.tokens = chunk.tokens[:break_location]
+                    chunk.multipliers = chunk.multipliers[:break_location]
 
-                    rem = int(math.ceil(length / 75)) * 75 - length
-                    remade_tokens += [self.id_end] * rem + reloc_tokens
-                    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+                    next_chunk()
+                    chunk.tokens = reloc_tokens
+                    chunk.multipliers = reloc_mults
 
+                if len(chunk.tokens) == self.chunk_length:
+                    next_chunk()
+
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                 if embedding is None:
-                    remade_tokens.append(token)
-                    multipliers.append(weight)
-                    i += 1
-                else:
-                    emb_len = int(embedding.vec.shape[0])
-                    iteration = len(remade_tokens) // 75
-                    if (len(remade_tokens) + emb_len) // 75 != iteration:
-                        rem = (75 * (iteration + 1) - len(remade_tokens))
-                        remade_tokens += [self.id_end] * rem
-                        multipliers += [1.0] * rem
-                        iteration += 1
-                    fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
-                    remade_tokens += [0] * emb_len
-                    multipliers += [weight] * emb_len
-                    used_custom_terms.append((embedding.name, embedding.checksum()))
-                    i += embedding_length_in_tokens
-
-        token_count = len(remade_tokens)
-        prompt_target_length = get_target_prompt_token_count(token_count)
-        tokens_to_add = prompt_target_length - len(remade_tokens)
-
-        remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
-        multipliers = multipliers + [1.0] * tokens_to_add
-
-        return remade_tokens, fixes, multipliers, token_count
-
-    def process_text(self, texts):
-        used_custom_terms = []
-        remade_batch_tokens = []
-        hijack_comments = []
-        hijack_fixes = []
+                    chunk.tokens.append(token)
+                    chunk.multipliers.append(weight)
+                    position += 1
+                    continue
+
+                emb_len = int(embedding.vec.shape[0])
+                if len(chunk.tokens) + emb_len > self.chunk_length:
+                    next_chunk()
+
+                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+                chunk.tokens += [0] * emb_len
+                chunk.multipliers += [weight] * emb_len
+                position += embedding_length_in_tokens
+
+        if len(chunk.tokens) > 0:
+            next_chunk()
+
+        return chunks, token_count
+
+    def process_texts(self, texts):
+        """
+        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+        length, in tokens, of all texts.
+        """
+
         token_count = 0
 
         cache = {}
-        batch_multipliers = []
+        batch_chunks = []
         for line in texts:
             if line in cache:
-                remade_tokens, fixes, multipliers = cache[line]
+                chunks = cache[line]
             else:
-                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+                chunks, current_token_count = self.tokenize_line(line)
                 token_count = max(current_token_count, token_count)
 
-                cache[line] = (remade_tokens, fixes, multipliers)
+                cache[line] = chunks
 
-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
+            batch_chunks.append(chunks)
 
-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+        return batch_chunks, token_count
 
-    def process_text_old(self, texts):
-        id_start = self.id_start
-        id_end = self.id_end
-        maxlen = self.wrapped.max_length  # you get to stay at 77
-        used_custom_terms = []
-        remade_batch_tokens = []
-        hijack_comments = []
-        hijack_fixes = []
-        token_count = 0
+    def forward(self, texts):
+        """
+        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        An example shape returned by this function can be: (2, 77, 768).
+        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
+        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+        """
 
-        cache = {}
-        batch_tokens = self.tokenize(texts)
-        batch_multipliers = []
-        for tokens in batch_tokens:
-            tuple_tokens = tuple(tokens)
+        if opts.use_old_emphasis_implementation:
+            import modules.sd_hijack_clip_old
+            return modules.sd_hijack_clip_old.forward_old(self, texts)
 
-            if tuple_tokens in cache:
-                remade_tokens, fixes, multipliers = cache[tuple_tokens]
-            else:
-                fixes = []
-                remade_tokens = []
-                multipliers = []
-                mult = 1.0
-
-                i = 0
-                while i < len(tokens):
-                    token = tokens[i]
-
-                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
-                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
-                    if mult_change is not None:
-                        mult *= mult_change
-                        i += 1
-                    elif embedding is None:
-                        remade_tokens.append(token)
-                        multipliers.append(mult)
-                        i += 1
-                    else:
-                        emb_len = int(embedding.vec.shape[0])
-                        fixes.append((len(remade_tokens), embedding))
-                        remade_tokens += [0] * emb_len
-                        multipliers += [mult] * emb_len
-                        used_custom_terms.append((embedding.name, embedding.checksum()))
-                        i += embedding_length_in_tokens
-
-                if len(remade_tokens) > maxlen - 2:
-                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
-                    ovf = remade_tokens[maxlen - 2:]
-                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
-                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
-                token_count = len(remade_tokens)
-                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
-                remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
-                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
-            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
-            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
-    def forward(self, text):
-        use_old = opts.use_old_emphasis_implementation
-        if use_old:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
-        else:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
-        self.hijack.comments += hijack_comments
-
-        if len(used_custom_terms) > 0:
-            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
-        if use_old:
-            self.hijack.fixes = hijack_fixes
-            return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
-        z = None
-        i = 0
-        while max(map(len, remade_batch_tokens)) != 0:
-            rem_tokens = [x[75:] for x in remade_batch_tokens]
-            rem_multipliers = [x[75:] for x in batch_multipliers]
-
-            self.hijack.fixes = []
-            for unfiltered in hijack_fixes:
-                fixes = []
-                for fix in unfiltered:
-                    if fix[0] == i:
-                        fixes.append(fix[1])
-                self.hijack.fixes.append(fixes)
-
-            tokens = []
-            multipliers = []
-            for j in range(len(remade_batch_tokens)):
-                if len(remade_batch_tokens[j]) > 0:
-                    tokens.append(remade_batch_tokens[j][:75])
-                    multipliers.append(batch_multipliers[j][:75])
-                else:
-                    tokens.append([self.id_end] * 75)
-                    multipliers.append([1.0] * 75)
-
-            z1 = self.process_tokens(tokens, multipliers)
-            z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
-            remade_batch_tokens = rem_tokens
-            batch_multipliers = rem_multipliers
-            i += 1
+        batch_chunks, token_count = self.process_texts(texts)
 
-        return z
+        used_embeddings = {}
+        chunk_count = max([len(x) for x in batch_chunks])
 
-    def process_tokens(self, remade_batch_tokens, batch_multipliers):
-        if not opts.use_old_emphasis_implementation:
-            remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
-            batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+        zs = []
+        for i in range(chunk_count):
+            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+            tokens = [x.tokens for x in batch_chunk]
+            multipliers = [x.multipliers for x in batch_chunk]
+            self.hijack.fixes = [x.fixes for x in batch_chunk]
 
+            for fixes in self.hijack.fixes:
+                for position, embedding in fixes:
+                    used_embeddings[embedding.name] = embedding
+
+            z = self.process_tokens(tokens, multipliers)
+            zs.append(z)
+
+        if len(used_embeddings) > 0:
+            embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+        return torch.hstack(zs)
+
+    def process_tokens(self, remade_batch_tokens, batch_multipliers):
+        """
+        sends one single prompt chunk to be encoded by transformers neural network.
+        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
+        corresponds to one token.
+        """
         tokens = torch.asarray(remade_batch_tokens).to(devices.device)
 
+        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
         if self.id_end != self.id_pad:
             for batch_pos in range(len(remade_batch_tokens)):
                 index = remade_batch_tokens[batch_pos].index(self.id_end)
@@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         z = self.encode_with_transformers(tokens)
 
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
-        batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
-        batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
         original_mean = z.mean()
         z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         new_mean = z.mean()

+ 81 - 0
modules/sd_hijack_clip_old.py

@@ -0,0 +1,81 @@
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+    id_start = self.id_start
+    id_end = self.id_end
+    maxlen = self.wrapped.max_length  # you get to stay at 77
+    used_custom_terms = []
+    remade_batch_tokens = []
+    hijack_comments = []
+    hijack_fixes = []
+    token_count = 0
+
+    cache = {}
+    batch_tokens = self.tokenize(texts)
+    batch_multipliers = []
+    for tokens in batch_tokens:
+        tuple_tokens = tuple(tokens)
+
+        if tuple_tokens in cache:
+            remade_tokens, fixes, multipliers = cache[tuple_tokens]
+        else:
+            fixes = []
+            remade_tokens = []
+            multipliers = []
+            mult = 1.0
+
+            i = 0
+            while i < len(tokens):
+                token = tokens[i]
+
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+                mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+                if mult_change is not None:
+                    mult *= mult_change
+                    i += 1
+                elif embedding is None:
+                    remade_tokens.append(token)
+                    multipliers.append(mult)
+                    i += 1
+                else:
+                    emb_len = int(embedding.vec.shape[0])
+                    fixes.append((len(remade_tokens), embedding))
+                    remade_tokens += [0] * emb_len
+                    multipliers += [mult] * emb_len
+                    used_custom_terms.append((embedding.name, embedding.checksum()))
+                    i += embedding_length_in_tokens
+
+            if len(remade_tokens) > maxlen - 2:
+                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                ovf = remade_tokens[maxlen - 2:]
+                overflowing_words = [vocab.get(int(x), "") for x in ovf]
+                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+            token_count = len(remade_tokens)
+            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+        remade_batch_tokens.append(remade_tokens)
+        hijack_fixes.append(fixes)
+        batch_multipliers.append(multipliers)
+    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
+
+    self.hijack.comments += hijack_comments
+
+    if len(used_custom_terms) > 0:
+        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+    self.hijack.fixes = hijack_fixes
+    return self.process_tokens(remade_batch_tokens, batch_multipliers)

+ 0 - 1
modules/textual_inversion/textual_inversion.py

@@ -79,7 +79,6 @@ class EmbeddingDatabase:
 
         self.word_embeddings[embedding.name] = embedding
 
-        # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
         ids = model.cond_stage_model.tokenize([embedding.name])[0]
 
         first_id = ids[0]

+ 1 - 1
modules/ui.py

@@ -368,7 +368,7 @@ def update_token_counter(text, steps):
 
     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
     style_class = ' class="red"' if (token_count > max_length) else ""
     return f"<span {style_class}>{token_count}/{max_length}</span>"