# vectorstores.py — FAISS vector-store subclass with MMR search returning scores.
  1. from langchain.vectorstores import FAISS
  2. from typing import Any, Callable, List, Optional, Tuple, Dict
  3. from langchain.docstore.document import Document
  4. from langchain.docstore.base import Docstore
  5. from langchain.vectorstores.utils import maximal_marginal_relevance
  6. from langchain.embeddings.base import Embeddings
  7. import uuid
  8. from langchain.docstore.in_memory import InMemoryDocstore
  9. import numpy as np
  10. def dependable_faiss_import() -> Any:
  11. """Import faiss if available, otherwise raise error."""
  12. try:
  13. import faiss
  14. except ImportError:
  15. raise ValueError(
  16. "Could not import faiss python package. "
  17. "Please install it with `pip install faiss` "
  18. "or `pip install faiss-cpu` (depending on Python version)."
  19. )
  20. return faiss
  21. class FAISSVS(FAISS):
  22. def __init__(self,
  23. embedding_function: Callable[..., Any],
  24. index: Any,
  25. docstore: Docstore,
  26. index_to_docstore_id: Dict[int, str]):
  27. super().__init__(embedding_function, index, docstore, index_to_docstore_id)
  28. def max_marginal_relevance_search_by_vector(
  29. self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
  30. ) -> List[Tuple[Document, float]]:
  31. """Return docs selected using the maximal marginal relevance.
  32. Maximal marginal relevance optimizes for similarity to query AND diversity
  33. among selected documents.
  34. Args:
  35. embedding: Embedding to look up documents similar to.
  36. k: Number of Documents to return. Defaults to 4.
  37. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
  38. Returns:
  39. List of Documents with scores selected by maximal marginal relevance.
  40. """
  41. scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
  42. # -1 happens when not enough docs are returned.
  43. embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
  44. mmr_selected = maximal_marginal_relevance(
  45. np.array([embedding], dtype=np.float32), embeddings, k=k
  46. )
  47. selected_indices = [indices[0][i] for i in mmr_selected]
  48. selected_scores = [scores[0][i] for i in mmr_selected]
  49. docs = []
  50. for i, score in zip(selected_indices, selected_scores):
  51. if i == -1:
  52. # This happens when not enough docs are returned.
  53. continue
  54. _id = self.index_to_docstore_id[i]
  55. doc = self.docstore.search(_id)
  56. if not isinstance(doc, Document):
  57. raise ValueError(f"Could not find document for id {_id}, got {doc}")
  58. docs.append((doc, score))
  59. return docs
  60. def max_marginal_relevance_search(
  61. self,
  62. query: str,
  63. k: int = 4,
  64. fetch_k: int = 20,
  65. **kwargs: Any,
  66. ) -> List[Tuple[Document, float]]:
  67. """Return docs selected using the maximal marginal relevance.
  68. Maximal marginal relevance optimizes for similarity to query AND diversity
  69. among selected documents.
  70. Args:
  71. query: Text to look up documents similar to.
  72. k: Number of Documents to return. Defaults to 4.
  73. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
  74. Returns:
  75. List of Documents with scores selected by maximal marginal relevance.
  76. """
  77. embedding = self.embedding_function(query)
  78. docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
  79. return docs
  80. @classmethod
  81. def __from(
  82. cls,
  83. texts: List[str],
  84. embeddings: List[List[float]],
  85. embedding: Embeddings,
  86. metadatas: Optional[List[dict]] = None,
  87. **kwargs: Any,
  88. ) -> FAISS:
  89. faiss = dependable_faiss_import()
  90. index = faiss.IndexFlatIP(len(embeddings[0]))
  91. index.add(np.array(embeddings, dtype=np.float32))
  92. # # my code, for speeding up search
  93. # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
  94. # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
  95. # index.train(np.array(embeddings, dtype=np.float32))
  96. # index.add(np.array(embeddings, dtype=np.float32))
  97. documents = []
  98. for i, text in enumerate(texts):
  99. metadata = metadatas[i] if metadatas else {}
  100. documents.append(Document(page_content=text, metadata=metadata))
  101. index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
  102. docstore = InMemoryDocstore(
  103. {index_to_id[i]: doc for i, doc in enumerate(documents)}
  104. )
  105. return cls(embedding.embed_query, index, docstore, index_to_id)