imClumsyPanda 2 lata temu
rodzic
commit
2c1fd2bdd5
5 zmienionych plików z 32 dodaniami i 30 usunięciami
  1. 22 22
      api.py
  2. 2 1
      chains/local_doc_qa.py
  3. 2 2
      configs/model_config.py
  4. 2 1
      utils/__init__.py
  5. 4 4
      webui.py

+ 22 - 22
api.py

@@ -97,9 +97,9 @@ async def upload_file(
     files: Annotated[
         List[UploadFile], File(description="Multiple files as UploadFile")
     ],
-    local_doc_id: str = Form(..., description="Local document ID", example="doc_id_1"),
+    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
 ):
-    saved_path = get_folder_path(local_doc_id)
+    saved_path = get_folder_path(knowledge_base_id)
     if not os.path.exists(saved_path):
         os.makedirs(saved_path)
     for file in files:
@@ -107,17 +107,17 @@ async def upload_file(
         with open(file_path, "wb") as f:
             f.write(file.file.read())
 
-    local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(local_doc_id))
+    local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(knowledge_base_id))
     return BaseResponse()
 
 
 async def list_docs(
-    local_doc_id: Optional[str] = Query(description="Document ID", example="doc_id1")
+    knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
 ):
-    if local_doc_id:
-        local_doc_folder = get_folder_path(local_doc_id)
+    if knowledge_base_id:
+        local_doc_folder = get_folder_path(knowledge_base_id)
         if not os.path.exists(local_doc_folder):
-            return {"code": 1, "msg": f"document {local_doc_id} not found"}
+            return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
         all_doc_names = [
             doc
             for doc in os.listdir(local_doc_folder)
@@ -138,34 +138,34 @@ async def list_docs(
 
 
 async def delete_docs(
-    local_doc_id: str = Form(..., description="local doc id", example="doc_id_1"),
+    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
     doc_name: Optional[str] = Form(
         None, description="doc name", example="doc_name_1.pdf"
     ),
 ):
-    if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)):
-        return {"code": 1, "msg": f"document {local_doc_id} not found"}
+    if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id)):
+        return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
     if doc_name:
-        doc_path = get_file_path(local_doc_id, doc_name)
+        doc_path = get_file_path(knowledge_base_id, doc_name)
         if os.path.exists(doc_path):
             os.remove(doc_path)
         else:
             return {"code": 1, "msg": f"document {doc_name} not found"}
 
-        remain_docs = await list_docs(local_doc_id)
+        remain_docs = await list_docs(knowledge_base_id)
         if remain_docs["code"] != 0 or len(remain_docs["data"]) == 0:
-            shutil.rmtree(get_folder_path(local_doc_id), ignore_errors=True)
+            shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
         else:
             local_doc_qa.init_knowledge_vector_store(
-                get_folder_path(local_doc_id), get_vs_path(local_doc_id)
+                get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
             )
     else:
-        shutil.rmtree(get_folder_path(local_doc_id))
+        shutil.rmtree(get_folder_path(knowledge_base_id))
     return BaseResponse()
 
 
 async def chat(
-    local_doc_id: str = Body(..., description="Document ID", example="doc_id1"),
+    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
     history: List[List[str]] = Body(
         [],
@@ -178,9 +178,9 @@ async def chat(
         ],
     ),
 ):
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
+    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
     if not os.path.exists(vs_path):
-        raise ValueError(f"Document {local_doc_id} not found")
+        raise ValueError(f"Knowledge base {knowledge_base_id} not found")
 
     for resp, history in local_doc_qa.get_knowledge_based_answer(
         query=question, vs_path=vs_path, chat_history=history, streaming=True
@@ -200,12 +200,12 @@ async def chat(
     )
 
 
-async def stream_chat(websocket: WebSocket, local_doc_id: str):
+async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
     await websocket.accept()
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
+    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
 
     if not os.path.exists(vs_path):
-        await websocket.send_json({"error": f"document {local_doc_id} not found"})
+        await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
         await websocket.close()
         return
 
@@ -288,7 +288,7 @@ def main():
     args = parser.parse_args()
 
     app = FastAPI()
-    app.websocket("/chat-docs/stream-chat/{local_doc_id}")(stream_chat)
+    app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
     app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
     app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
     app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)

+ 2 - 1
chains/local_doc_qa.py

@@ -184,7 +184,8 @@ class LocalDocQA:
                 torch_gc(DEVICE)
             else:
                 if not vs_path:
-                    vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+                    vs_path = os.path.join(VS_ROOT_PATH,
+                                           f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
                 vector_store = FAISS.from_documents(docs, self.embeddings)
                 torch_gc(DEVICE)
 

+ 2 - 2
configs/model_config.py

@@ -36,9 +36,9 @@ USE_PTUNING_V2 = False
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
-VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store", "")
+VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
 
-UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
+UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
 
 API_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_content")
 

+ 2 - 1
utils/__init__.py

@@ -7,7 +7,8 @@ def torch_gc(DEVICE):
             torch.cuda.ipc_collect()
     elif torch.backends.mps.is_available():
         try:
-            torch.mps.empty_cache()
+            from torch.mps import empty_cache
+            empty_cache()
         except Exception as e:
             print(e)
             print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")

+ 4 - 4
webui.py

@@ -95,12 +95,12 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
 
 
 def get_vector_store(vs_id, files, history):
-    vs_path = VS_ROOT_PATH + vs_id
+    vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     filelist = []
     for file in files:
         filename = os.path.split(file.name)[-1]
-        shutil.move(file.name, UPLOAD_ROOT_PATH + filename)
-        filelist.append(UPLOAD_ROOT_PATH + filename)
+        shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, filename))
+        filelist.append(os.path.join(UPLOAD_ROOT_PATH, filename))
     if local_doc_qa.llm and local_doc_qa.embeddings:
         vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
         if len(loaded_files):
@@ -118,7 +118,7 @@ def change_vs_name_input(vs_id):
     if vs_id == "新建知识库":
         return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None
     else:
-        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), VS_ROOT_PATH + vs_id
+        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, vs_id)
 
 
 def change_mode(mode):