embeddings.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334
  1. from langchain.embeddings.huggingface import HuggingFaceEmbeddings
  2. from typing import Any, List
  3. class MyEmbeddings(HuggingFaceEmbeddings):
  4. def __init__(self, **kwargs: Any):
  5. super().__init__(**kwargs)
  6. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  7. """Compute doc embeddings using a HuggingFace transformer model.
  8. Args:
  9. texts: The list of texts to embed.
  10. Returns:
  11. List of embeddings, one for each text.
  12. """
  13. texts = list(map(lambda x: x.replace("\n", " "), texts))
  14. embeddings = self.client.encode(texts, normalize_embeddings=True)
  15. return embeddings.tolist()
  16. def embed_query(self, text: str) -> List[float]:
  17. """Compute query embeddings using a HuggingFace transformer model.
  18. Args:
  19. text: The text to embed.
  20. Returns:
  21. Embeddings for the text.
  22. """
  23. text = text.replace("\n", " ")
  24. embedding = self.client.encode(text, normalize_embeddings=True)
  25. return embedding.tolist()