Procházet zdrojové kódy

Merge pull request #10803 from klimaleksus/refactoring-for-embedding-merge

Refactor EmbeddingDatabase.register_embedding() to allow unregistering
AUTOMATIC1111 před 2 roky
rodič
revize
881de0df38
1 změnil soubory, kde provedl 19 přidání a 6 odebrání
  1. 19 6
      modules/textual_inversion/textual_inversion.py

+ 19 - 6
modules/textual_inversion/textual_inversion.py

@@ -119,16 +119,29 @@ class EmbeddingDatabase:
         self.embedding_dirs.clear()
 
     def register_embedding(self, embedding, model):
-        self.word_embeddings[embedding.name] = embedding
-
-        ids = model.cond_stage_model.tokenize([embedding.name])[0]
+        return self.register_embedding_by_name(embedding, model, embedding.name)
 
+    def register_embedding_by_name(self, embedding, model, name):
+        ids = model.cond_stage_model.tokenize([name])[0]
         first_id = ids[0]
         if first_id not in self.ids_lookup:
             self.ids_lookup[first_id] = []
-
-        self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+        if name in self.word_embeddings:
+            # remove old one from the lookup list
+            lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+        else:
+            lookup = self.ids_lookup[first_id]
+        if embedding is not None:
+            lookup += [(ids, embedding)]
+        self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+        if embedding is None:
+            # unregister embedding with specified name
+            if name in self.word_embeddings:
+                del self.word_embeddings[name]
+            if len(self.ids_lookup[first_id])==0:
+                del self.ids_lookup[first_id]
+            return None
+        self.word_embeddings[name] = embedding
         return embedding
 
     def get_expected_shape(self):