
update cli_demo.py

imClumsyPanda 2 years ago
parent
commit 1c51d6cafc
2 changed files with 8 additions and 6 deletions
  1. +6 -5  chains/local_doc_qa.py
  2. +2 -1  cli_demo.py

+6 -5  chains/local_doc_qa.py

@@ -28,7 +28,8 @@ class LocalDocQA:
                  embedding_device=EMBEDDING_DEVICE,
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
-                 llm_device=LLM_DEVICE
+                 llm_device=LLM_DEVICE,
+                 top_k=VECTOR_SEARCH_TOP_K,
                  ):
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
@@ -38,6 +39,7 @@ class LocalDocQA:
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
         self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
                                                                            device=embedding_device)
+        self.top_k = top_k

     def init_knowledge_vector_store(self,
                                     filepath: str):
@@ -65,15 +67,14 @@ class LocalDocQA:
                     print(f"{file} 未能成功加载")

         vector_store = FAISS.from_documents(docs, self.embeddings)
-        vs_path = f"""./vector_store/{os.path.splitext(file)}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+        vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
         vector_store.save_local(vs_path)
         return vs_path

     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[],
-                                   top_k=VECTOR_SEARCH_TOP_K):
+                                   chat_history=[],):
         prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
     如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。

@@ -90,7 +91,7 @@ class LocalDocQA:
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": top_k}),
+            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
             prompt=prompt
         )
         knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
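
The vs_path change above fixes a concrete bug: os.path.splitext() returns a (root, ext) tuple, so the old f-string baked the tuple's repr into the directory name; indexing with [0] keeps only the root. A minimal sketch of the before/after behaviour (the file name below is illustrative):

    import datetime
    import os

    file = "report.txt"

    # Before: the whole tuple repr lands in the path.
    os.path.splitext(file)  # -> ('report', '.txt')
    f"./vector_store/{os.path.splitext(file)}_FAISS"
    # -> "./vector_store/('report', '.txt')_FAISS"

    # After: [0] keeps just the root, so the path is clean.
    f"./vector_store/{os.path.splitext(file)[0]}_FAISS_" \
        f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # -> e.g. "./vector_store/report_FAISS_20230414_153000"

The remaining hunks move top_k from a per-call argument of get_knowledge_based_answer to an instance attribute set once in init_cfg, so the retriever's search_kwargs={"k": self.top_k} is now fixed at initialisation rather than per query.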

+2 -1  cli_demo.py

@@ -15,7 +15,8 @@ if __name__ == "__main__":
     local_doc_qa.init_cfg(llm_model=LLM_MODEL,
                           embedding_model=EMBEDDING_MODEL,
                           embedding_device=EMBEDDING_DEVICE,
-                          llm_history_len=LLM_HISTORY_LEN)
+                          llm_history_len=LLM_HISTORY_LEN,
+                          top_k=VECTOR_SEARCH_TOP_K)
     vs_path = None
     while not vs_path:
         filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
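
Taken together, a hedged sketch of the call flow after this commit; the config-constant import mirrors the repo's configs/model_config.py, and the file path and query string are placeholders, all assumptions here:

    from chains.local_doc_qa import LocalDocQA
    from configs.model_config import (LLM_MODEL, EMBEDDING_MODEL,
                                      EMBEDDING_DEVICE, LLM_HISTORY_LEN,
                                      VECTOR_SEARCH_TOP_K)

    local_doc_qa = LocalDocQA()
    # top_k is now set once here and stored as self.top_k.
    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
                          embedding_model=EMBEDDING_MODEL,
                          embedding_device=EMBEDDING_DEVICE,
                          llm_history_len=LLM_HISTORY_LEN,
                          top_k=VECTOR_SEARCH_TOP_K)

    vs_path = local_doc_qa.init_knowledge_vector_store("sample.txt")
    # top_k can no longer be passed per call; the chain reads self.top_k.
    local_doc_qa.get_knowledge_based_answer("your question", vs_path,
                                            chat_history=[])

One consequence of this design choice: changing the retrieval depth now means reassigning local_doc_qa.top_k (or re-running init_cfg), whereas before it could vary per query.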