12345678910111213141516171819202122232425262728293031323334 |
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
- from typing import Any, List
class MyEmbeddings(HuggingFaceEmbeddings):
    """HuggingFace embeddings whose every encode call requests normalization.

    Behaves like ``HuggingFaceEmbeddings`` except that each call to
    ``self.client.encode`` is made with ``normalize_embeddings=True``, and
    newlines in the input text are replaced with spaces before encoding.
    Any ``encode_kwargs`` configured on the instance are forwarded too, so
    options such as ``batch_size`` are honoured rather than silently dropped.
    """

    def _encode_kwargs(self) -> dict:
        """Return the keyword arguments to pass to ``self.client.encode``.

        Starts from the instance's ``encode_kwargs`` and forces
        ``normalize_embeddings=True``, because normalization is the reason
        this subclass exists. It is applied last so configuration cannot
        turn it off.
        """
        # ``encode_kwargs`` is absent on some older langchain releases, hence
        # the getattr fallback; ``or {}`` also covers an explicit None.
        configured = getattr(self, "encode_kwargs", None) or {}
        return {**configured, "normalize_embeddings": True}

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute normalized doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        if not texts:
            # Nothing to encode; don't depend on how the client treats an
            # empty batch.
            return []
        cleaned = [text.replace("\n", " ") for text in texts]
        embeddings = self.client.encode(cleaned, **self._encode_kwargs())
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute a normalized query embedding using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        text = text.replace("\n", " ")
        embedding = self.client.encode(text, **self._encode_kwargs())
        return embedding.tolist()
|