
Merge pull request #6196 from philpax/add-embeddings-api

feat(api): add /sdapi/v1/embeddings
AUTOMATIC1111 2 years ago
parent commit fd4461d44c
3 changed files with 35 additions and 4 deletions
  1. modules/api/api.py (+21 -0)
  2. modules/api/models.py (+10 -0)
  3. modules/textual_inversion/textual_inversion.py (+4 -4)

modules/api/api.py (+21 -0)

@@ -100,6 +100,7 @@ class Api:
         self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
         self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
         self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@@ -327,6 +328,26 @@ class Api:
     def get_artists(self):
         return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
 
+    def get_embeddings(self):
+        db = sd_hijack.model_hijack.embedding_db
+
+        def convert_embedding(embedding):
+            return {
+                "step": embedding.step,
+                "sd_checkpoint": embedding.sd_checkpoint,
+                "sd_checkpoint_name": embedding.sd_checkpoint_name,
+                "shape": embedding.shape,
+                "vectors": embedding.vectors,
+            }
+
+        def convert_embeddings(embeddings):
+            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+        return {
+            "loaded": convert_embeddings(db.word_embeddings),
+            "skipped": convert_embeddings(db.skipped_embeddings),
+        }
+
     def refresh_checkpoints(self):
         shared.refresh_checkpoints()
 

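For context, a minimal sketch of how a client might query the new route once the webui is running with the API enabled; the base URL below assumes the webui's default host and port, which is an assumption on my part:

import requests

# Assumption: the webui was started with --api and listens on its default port.
BASE_URL = "http://127.0.0.1:7860"

resp = requests.get(f"{BASE_URL}/sdapi/v1/embeddings")
resp.raise_for_status()
data = resp.json()

# Both keys are dicts keyed by embedding name, per EmbeddingsResponse below.
for name, item in data["loaded"].items():
    print(f"{name}: {item['vectors']} vector(s) of length {item['shape']}")
if data["skipped"]:
    print("skipped:", ", ".join(data["skipped"]))
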
modules/api/models.py (+10 -0)

@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
     score: float = Field(title="Score")
     category: str = Field(title="Category")
 
+class EmbeddingItem(BaseModel):
+    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")

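As a quick sanity check, the invented sample payload below validates against trimmed copies of the two new models; the webui used pydantic v1 at the time, hence parse_obj:

from typing import Dict, Optional
from pydantic import BaseModel, Field

class EmbeddingItem(BaseModel):
    step: Optional[int] = Field(title="Step")
    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint")
    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name")
    shape: int = Field(title="Shape")
    vectors: int = Field(title="Vectors")

class EmbeddingsResponse(BaseModel):
    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded")
    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped")

# Invented example values in the shape the endpoint returns.
sample = {
    "loaded": {
        "my-style": {"step": 5000, "sd_checkpoint": "fe4efff1",
                     "sd_checkpoint_name": "sd-v1-4", "shape": 768, "vectors": 4},
    },
    "skipped": {},
}

print(EmbeddingsResponse.parse_obj(sample).loaded["my-style"].vectors)  # 4
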
modules/textual_inversion/textual_inversion.py (+4 -4)

@@ -59,7 +59,7 @@ class EmbeddingDatabase:
     def __init__(self, embeddings_dir):
         self.ids_lookup = {}
         self.word_embeddings = {}
-        self.skipped_embeddings = []
+        self.skipped_embeddings = {}
         self.dir_mtime = None
         self.embeddings_dir = embeddings_dir
         self.expected_shape = -1
@@ -91,7 +91,7 @@ class EmbeddingDatabase:
         self.dir_mtime = mt
         self.ids_lookup.clear()
         self.word_embeddings.clear()
-        self.skipped_embeddings = []
+        self.skipped_embeddings.clear()
         self.expected_shape = self.get_expected_shape()
 
         def process_file(path, filename):
@@ -136,7 +136,7 @@ class EmbeddingDatabase:
             if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                 self.register_embedding(embedding, shared.sd_model)
             else:
-                self.skipped_embeddings.append(name)
+                self.skipped_embeddings[name] = embedding
 
         for fn in os.listdir(self.embeddings_dir):
             try:
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
 
         print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
         if len(self.skipped_embeddings) > 0:
-            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
+            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
 
     def find_embedding_at_position(self, tokens, offset):
         token = tokens[offset]
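
The list-to-dict change matters for the API above: keeping the Embedding objects themselves, keyed by name rather than just collecting names, is what lets get_embeddings() run the same convert_embeddings() over both word_embeddings and skipped_embeddings. A minimal illustration, with SimpleNamespace standing in for the real Embedding class:

from types import SimpleNamespace

# Stand-in for the real Embedding object; attribute names match the ones
# convert_embedding() reads.
emb = SimpleNamespace(name="my-style", step=None, sd_checkpoint=None,
                      sd_checkpoint_name=None, shape=1024, vectors=4)

skipped = {}                 # was: skipped = []
skipped[emb.name] = emb      # was: skipped.append(name)

# Metadata remains reachable for the API, not just the name.
print(skipped["my-style"].shape)   # 1024
print(", ".join(skipped.keys()))   # my-style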