Selaa lähdekoodia

update local_doc_qa.py

imClumsyPanda 2 vuotta sitten
vanhempi
commit
7d4560e599
3 muutettua tiedostoa, joissa 20 lisäystä ja 8 poistoa
  1. + 4 - 3
      api.py
  2. + 15 - 4
      chains/local_doc_qa.py
  3. + 1 - 1
      cli_demo.py

+ 4 - 3
api.py

@@ -54,10 +54,11 @@ async def upload_file(UserFile: UploadFile=File(...)):
         # print(UserFile.filename)
         with open(filepath, 'wb') as f:
             f.write(content)
-        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+        vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
         response = {
-            'msg': 'seccessful',
-            'status': 1
+            'msg': 'seccess' if len(files)>0 else 'fail',
+            'status': 1 if len(files)>0 else 0,
+            'loaded_files': files
         }
         
     except Exception as err:

+ 15 - 4
chains/local_doc_qa.py

@@ -53,7 +53,9 @@ class LocalDocQA:
         self.top_k = top_k
 
     def init_knowledge_vector_store(self,
-                                    filepath: str or List[str]):
+                                    filepath: str or List[str],
+                                    vs_path: str or os.PathLike = None):
+        loaded_files = []
         if isinstance(filepath, str):
             if not os.path.exists(filepath):
                 print("路径不存在")
@@ -63,6 +65,7 @@ class LocalDocQA:
                 try:
                     docs = load_file(filepath)
                     print(f"{file} 已成功加载")
+                    loaded_files.append(filepath)
                 except Exception as e:
                     print(e)
                     print(f"{file} 未能成功加载")
@@ -74,6 +77,7 @@ class LocalDocQA:
                     try:
                         docs += load_file(fullfilepath)
                         print(f"{file} 已成功加载")
+                        loaded_files.append(fullfilepath)
                     except Exception as e:
                         print(e)
                         print(f"{file} 未能成功加载")
@@ -83,14 +87,21 @@ class LocalDocQA:
                 try:
                     docs += load_file(file)
                     print(f"{file} 已成功加载")
+                    loaded_files.append(file)
                 except Exception as e:
                     print(e)
                     print(f"{file} 未能成功加载")
 
-        vector_store = FAISS.from_documents(docs, self.embeddings)
-        vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+        if vs_path and os.path.isdir(vs_path):
+            vector_store = FAISS.load_local(vs_path, self.embeddings)
+            vector_store.add_documents(docs)
+        else:
+            if not vs_path:
+                vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+            vector_store = FAISS.from_documents(docs, self.embeddings)
+
         vector_store.save_local(vs_path)
-        return vs_path if len(docs) > 0 else None
+        return vs_path if len(docs) > 0 else None, loaded_files
 
     def get_knowledge_based_answer(self,
                                    query,

+ 1 - 1
cli_demo.py

@@ -24,7 +24,7 @@ if __name__ == "__main__":
     vs_path = None
     while not vs_path:
         filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
-        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+        vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
     history = []
     while True:
         query = input("Input your question 请输入问题:")