|
@@ -33,6 +33,7 @@ def load_file(filepath):
|
|
|
class LocalDocQA:
|
|
|
llm: object = None
|
|
|
embeddings: object = None
|
|
|
+ top_k: int = VECTOR_SEARCH_TOP_K
|
|
|
|
|
|
def init_cfg(self,
|
|
|
embedding_model: str = EMBEDDING_MODEL,
|
|
@@ -49,9 +50,10 @@ class LocalDocQA:
|
|
|
use_ptuning_v2=use_ptuning_v2)
|
|
|
self.llm.history_len = llm_history_len
|
|
|
|
|
|
- self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
|
|
|
- self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
|
|
|
- device=embedding_device)
|
|
|
+ self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
|
|
|
+ model_kwargs={'device': embedding_device})
|
|
|
+ # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
|
|
|
+ # device=embedding_device)
|
|
|
self.top_k = top_k
|
|
|
|
|
|
def init_knowledge_vector_store(self,
|