|
@@ -28,7 +28,8 @@ class LocalDocQA:
|
|
embedding_device=EMBEDDING_DEVICE,
|
|
embedding_device=EMBEDDING_DEVICE,
|
|
llm_history_len: int = LLM_HISTORY_LEN,
|
|
llm_history_len: int = LLM_HISTORY_LEN,
|
|
llm_model: str = LLM_MODEL,
|
|
llm_model: str = LLM_MODEL,
|
|
- llm_device=LLM_DEVICE
|
|
|
|
|
|
+ llm_device=LLM_DEVICE,
|
|
|
|
+ top_k=VECTOR_SEARCH_TOP_K,
|
|
):
|
|
):
|
|
self.llm = ChatGLM()
|
|
self.llm = ChatGLM()
|
|
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
|
|
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
|
|
@@ -38,6 +39,7 @@ class LocalDocQA:
|
|
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
|
|
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
|
|
self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
|
|
self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
|
|
device=embedding_device)
|
|
device=embedding_device)
|
|
|
|
+ self.top_k = top_k
|
|
|
|
|
|
def init_knowledge_vector_store(self,
|
|
def init_knowledge_vector_store(self,
|
|
filepath: str):
|
|
filepath: str):
|
|
@@ -65,15 +67,14 @@ class LocalDocQA:
|
|
print(f"{file} 未能成功加载")
|
|
print(f"{file} 未能成功加载")
|
|
|
|
|
|
vector_store = FAISS.from_documents(docs, self.embeddings)
|
|
vector_store = FAISS.from_documents(docs, self.embeddings)
|
|
- vs_path = f"""./vector_store/{os.path.splitext(file)}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
|
|
|
|
|
+ vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
|
vector_store.save_local(vs_path)
|
|
vector_store.save_local(vs_path)
|
|
return vs_path
|
|
return vs_path
|
|
|
|
|
|
def get_knowledge_based_answer(self,
|
|
def get_knowledge_based_answer(self,
|
|
query,
|
|
query,
|
|
vs_path,
|
|
vs_path,
|
|
- chat_history=[],
|
|
|
|
- top_k=VECTOR_SEARCH_TOP_K):
|
|
|
|
|
|
+ chat_history=[],):
|
|
prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
|
|
prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
|
|
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
|
|
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
|
|
|
|
|
|
@@ -90,7 +91,7 @@ class LocalDocQA:
|
|
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
|
vector_store = FAISS.load_local(vs_path, self.embeddings)
|
|
knowledge_chain = RetrievalQA.from_llm(
|
|
knowledge_chain = RetrievalQA.from_llm(
|
|
llm=self.llm,
|
|
llm=self.llm,
|
|
- retriever=vector_store.as_retriever(search_kwargs={"k": top_k}),
|
|
|
|
|
|
+ retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
|
|
prompt=prompt
|
|
prompt=prompt
|
|
)
|
|
)
|
|
knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
|
|
knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
|