
feat: rewrite FAISS and HuggingFaceEmbeddings to use MMR-based similarity vector retrieval

wangxinkai 2 years ago
commit daafe8d5fa
2 changed files with 155 additions and 0 deletions
  1. chains/modules/embeddings.py  +34 -0
  2. chains/modules/vectorstores.py  +121 -0

chains/modules/embeddings.py  +34 -0

@@ -0,0 +1,34 @@
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+
+from typing import Any, List
+
+
+class MyEmbeddings(HuggingFaceEmbeddings):
+    """HuggingFaceEmbeddings that L2-normalizes every embedding it returns."""
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Compute doc embeddings using a HuggingFace transformer model.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        texts = list(map(lambda x: x.replace("\n", " "), texts))
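+        # normalize_embeddings=True makes the vectors unit-length, so
+        # inner-product search (faiss.IndexFlatIP) equals cosine similarity.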
+        embeddings = self.client.encode(texts, normalize_embeddings=True)
+        return embeddings.tolist()
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using a HuggingFace transformer model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        text = text.replace("\n", " ")
+        embedding = self.client.encode(text, normalize_embeddings=True)
+        return embedding.tolist()

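The override's only functional change from the stock HuggingFaceEmbeddings is normalize_embeddings=True, which makes every vector unit-length so that the inner-product FAISS index defined below behaves like cosine similarity. A minimal usage sketch (the model name is illustrative, not part of this commit):

    from chains.modules.embeddings import MyEmbeddings

    # Any sentence-transformers model works; this one is a placeholder.
    emb = MyEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vec = emb.embed_query("what does MMR retrieval change?")
    print(sum(v * v for v in vec))  # ~1.0, since vectors are unit-normalized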
chains/modules/vectorstores.py  +121 -0

@@ -0,0 +1,121 @@
+from langchain.vectorstores import FAISS
+from typing import Any, Callable, List, Optional, Tuple, Dict
+from langchain.docstore.document import Document
+from langchain.docstore.base import Docstore
+
+from langchain.vectorstores.utils import maximal_marginal_relevance
+from langchain.embeddings.base import Embeddings
+import uuid
+from langchain.docstore.in_memory import InMemoryDocstore
+
+import numpy as np
+
+def dependable_faiss_import() -> Any:
+    """Import faiss if available, otherwise raise an error."""
+    try:
+        import faiss
+    except ImportError:
+        raise ImportError(
+            "Could not import faiss python package. "
+            "Please install it with `pip install faiss-cpu` "
+            "(or `pip install faiss-gpu` for CUDA-enabled GPUs)."
+        )
+    return faiss
+
+class FAISSVS(FAISS):
+    def __init__(self, 
+                 embedding_function: Callable[..., Any], 
+                 index: Any, 
+                 docstore: Docstore, 
+                 index_to_docstore_id: Dict[int, str]):
+        super().__init__(embedding_function, index, docstore, index_to_docstore_id)
+
+    def max_marginal_relevance_search_by_vector(
+        self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+
+        Returns:
+            List of Documents with scores selected by maximal marginal relevance.
+        """
+        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
+        # -1 happens when not enough docs are returned.
+        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
+        mmr_selected = maximal_marginal_relevance(
+            np.array([embedding], dtype=np.float32), embeddings, k=k
+        )
+        selected_indices = [indices[0][i] for i in mmr_selected]
+        selected_scores = [scores[0][i] for i in mmr_selected]
+        docs = []
+        for i, score in zip(selected_indices, selected_scores):
+            if i == -1:
+                # This happens when not enough docs are returned.
+                continue
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
+            if not isinstance(doc, Document):
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
+            docs.append((doc, score))
+        return docs
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+
+        Returns:
+            List of Documents with scores selected by maximal marginal relevance.
+        """
+        embedding = self.embedding_function(query)
+        docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
+        return docs
+    
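+    # NOTE: this name is mangled to _FAISSVS__from, so the inherited
+    # FAISS.from_texts (which calls cls._FAISS__from) will not dispatch
+    # here; see the from_texts sketch after this diff.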
+    @classmethod
+    def __from(
+        cls,
+        texts: List[str],
+        embeddings: List[List[float]],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> FAISS:
+        faiss = dependable_faiss_import()
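+        # Exact inner-product index; with the unit-normalized embeddings from
+        # MyEmbeddings this ranks by cosine similarity (the parent class
+        # builds an L2-distance index instead).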
+        index = faiss.IndexFlatIP(len(embeddings[0]))
+        index.add(np.array(embeddings, dtype=np.float32))
+
+        # Optional alternative (kept commented out): an IVF index for faster,
+        # approximate search on large corpora; needs training before adding.
+        # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
+        # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
+        # index.train(np.array(embeddings, dtype=np.float32))
+        # index.add(np.array(embeddings, dtype=np.float32))
+
+        documents = []
+        for i, text in enumerate(texts):
+            metadata = metadatas[i] if metadatas else {}
+            documents.append(Document(page_content=text, metadata=metadata))
+        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+        docstore = InMemoryDocstore(
+            {index_to_id[i]: doc for i, doc in enumerate(documents)}
+        )
+        return cls(embedding.embed_query, index, docstore, index_to_id)
+
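One caveat worth noting: Python name-mangles `__from` to `_FAISSVS__from`, while the inherited `FAISS.from_texts` calls `cls.__from` from inside the parent class body (i.e. `cls._FAISS__from`), so building a store via `from_texts` would silently fall back to the parent's L2 index. A minimal sketch of the missing override, assuming the `FAISS.from_texts` signature of this langchain version and reusing the imports already present in vectorstores.py:

    class FAISSVS(FAISS):
        ...

        @classmethod
        def from_texts(
            cls,
            texts: List[str],
            embedding: Embeddings,
            metadatas: Optional[List[dict]] = None,
            **kwargs: Any,
        ) -> "FAISSVS":
            embeddings = embedding.embed_documents(texts)
            # Inside this class body, cls.__from resolves to FAISSVS.__from,
            # so the IndexFlatIP path above is actually used.
            return cls.__from(texts, embeddings, embedding, metadatas, **kwargs)

With that in place, MMR retrieval looks like:

    store = FAISSVS.from_texts(["doc one", "doc two"], MyEmbeddings())
    results = store.max_marginal_relevance_search("query", k=4, fetch_k=20)
    # results is a list of (Document, score) pairs; note the base class
    # returns bare Documents from this method.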