浏览代码

完善知识库路径问题,完善api接口 (#245)

* Fix 知识库无法上载,NLTK_DATA_PATH路径错误 (#236)

* Update chatglm_llm.py (#242)

* 完善知识库路径问题,完善api接口

统一webui、API接口知识库路径,后续路径如下:
知识库路经就是:/项目代码文件夹/vector_store/'知识库名字'
文件存放路经:/项目代码文件夹/content/'知识库名字'

修复通过api接口创建知识库的BUG,完善API接口功能。

* Update model_config.py

---------

Co-authored-by: Bob Chang <bob-chang@outlook.com>
Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
shrimp 2 年之前
父节点
当前提交
7497b261b3
共有 4 个文件被更改,包括 62 次插入94 次删除
  1. 55 80
      api.py
  2. 1 4
      configs/model_config.py
  3. 2 2
      models/chatglm_llm.py
  4. 4 8
      webui.py

+ 55 - 80
api.py

@@ -13,11 +13,10 @@ from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.openapi.utils import get_openapi
 from pydantic import BaseModel
 from typing_extensions import Annotated
-
+from starlette.responses import RedirectResponse
 from chains.local_doc_qa import LocalDocQA
-from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
-                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
-                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
+from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH,
+                                  NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
@@ -76,37 +75,47 @@ class ChatMessage(BaseModel):
 
 
 def get_folder_path(local_doc_id: str):
-    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)
+    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id)
 
 
 def get_vs_path(local_doc_id: str):
-    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
+    return os.path.join(VS_ROOT_PATH, local_doc_id)
 
 
 def get_file_path(local_doc_id: str, doc_name: str):
-    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, doc_name)
+    return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
 
 
 async def upload_file(
-    files: Annotated[
-        List[UploadFile], File(description="Multiple files as UploadFile")
-    ],
-    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
+        files: Annotated[
+            List[UploadFile], File(description="Multiple files as UploadFile")
+        ],
+        knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
 ):
     saved_path = get_folder_path(knowledge_base_id)
     if not os.path.exists(saved_path):
         os.makedirs(saved_path)
+    filelist = []
     for file in files:
+        file_content = ''
         file_path = os.path.join(saved_path, file.filename)
-        with open(file_path, "wb") as f:
-            f.write(file.file.read())
-
-    local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(knowledge_base_id))
-    return BaseResponse()
+        file_content = file.file.read()
+        if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
+            continue
+        with open(file_path, "ab+") as f:
+            f.write(file_content)
+        filelist.append(file_path)
+    if filelist:
+        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
+        if len(loaded_files):
+            file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
+            return BaseResponse(code=200, msg=file_status)
+    file_status = "文件未成功加载,请重新上传文件"
+    return BaseResponse(code=500, msg=file_status)
 
 
 async def list_docs(
-    knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
+        knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
 ):
     if knowledge_base_id:
         local_doc_folder = get_folder_path(knowledge_base_id)
@@ -119,25 +128,27 @@ async def list_docs(
         ]
         return ListDocsResponse(data=all_doc_names)
     else:
-        if not os.path.exists(API_UPLOAD_ROOT_PATH):
+        if not os.path.exists(UPLOAD_ROOT_PATH):
             all_doc_ids = []
         else:
             all_doc_ids = [
                 folder
-                for folder in os.listdir(API_UPLOAD_ROOT_PATH)
-                if os.path.isdir(os.path.join(API_UPLOAD_ROOT_PATH, folder))
+                for folder in os.listdir(UPLOAD_ROOT_PATH)
+                if os.path.isdir(os.path.join(UPLOAD_ROOT_PATH, folder))
             ]
 
         return ListDocsResponse(data=all_doc_ids)
 
 
 async def delete_docs(
-    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
-    doc_name: Optional[str] = Form(
-        None, description="doc name", example="doc_name_1.pdf"
-    ),
+        knowledge_base_id: str = Form(...,
+                                      description="Knowledge Base Name(注意此方法仅删除上传的文件并不会删除知识库(FAISS)内数据)",
+                                      example="kb1"),
+        doc_name: Optional[str] = Form(
+            None, description="doc name", example="doc_name_1.pdf"
+        ),
 ):
-    if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id)):
+    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id)):
         return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
     if doc_name:
         doc_path = get_file_path(knowledge_base_id, doc_name)
@@ -159,25 +170,25 @@ async def delete_docs(
 
 
 async def chat(
-    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
-    question: str = Body(..., description="Question", example="工伤保险是什么?"),
-    history: List[List[str]] = Body(
-        [],
-        description="History of previous questions and answers",
-        example=[
-            [
-                "工伤保险是什么?",
-                "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
-            ]
-        ],
-    ),
+        knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
+        question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        history: List[List[str]] = Body(
+            [],
+            description="History of previous questions and answers",
+            example=[
+                [
+                    "工伤保险是什么?",
+                    "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                ]
+            ],
+        ),
 ):
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
+    vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
     if not os.path.exists(vs_path):
         raise ValueError(f"Knowledge base {knowledge_base_id} not found")
 
     for resp, history in local_doc_qa.get_knowledge_based_answer(
-        query=question, vs_path=vs_path, chat_history=history, streaming=True
+            query=question, vs_path=vs_path, chat_history=history, streaming=True
     ):
         pass
     source_documents = [
@@ -196,7 +207,7 @@ async def chat(
 
 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
     await websocket.accept()
-    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
+    vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
 
     if not os.path.exists(vs_path):
         await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
@@ -211,7 +222,7 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
 
         last_print_len = 0
         for resp, history in local_doc_qa.get_knowledge_based_answer(
-            query=question, vs_path=vs_path, chat_history=history, streaming=True
+                query=question, vs_path=vs_path, chat_history=history, streaming=True
         ):
             await websocket.send_text(resp["result"][last_print_len:])
             last_print_len = len(resp["result"])
@@ -236,40 +247,8 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
         turn += 1
 
 
-def gen_docs():
-    global app
-    with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=".json") as f:
-        json.dump(
-            get_openapi(
-                title=app.title,
-                version=app.version,
-                openapi_version=app.openapi_version,
-                description=app.description,
-                routes=app.routes,
-            ),
-            f,
-            ensure_ascii=False,
-        )
-        f.flush()
-        # test whether widdershins is available
-        try:
-            subprocess.run(
-                [
-                    "widdershins",
-                    f.name,
-                    "-o",
-                    os.path.join(
-                        os.path.dirname(os.path.abspath(__file__)),
-                        "docs",
-                        "API.md",
-                    ),
-                ],
-                check=True,
-            )
-        except Exception:
-            raise RuntimeError(
-                "Failed to generate docs. Please install widdershins first."
-            )
+async def document():
+    return RedirectResponse(url="/docs")
 
 
 def main():
@@ -278,7 +257,6 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int, default=7861)
-    parser.add_argument("--gen-docs", action="store_true")
     args = parser.parse_args()
 
     app = FastAPI()
@@ -287,10 +265,7 @@ def main():
     app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
     app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
     app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
-
-    if args.gen_docs:
-        gen_docs()
-        return
+    app.get("/", response_model=BaseResponse)(document)
 
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg(

+ 1 - 4
configs/model_config.py

@@ -28,7 +28,6 @@ llm_model_dict = {
 LLM_MODEL = "chatglm-6b"
 
 # LLM lora path,默认为空,如果有请直接指定文件夹路径
-# 推荐使用 chatglm-6b-belle-zh-lora
 LLM_LORA_PATH = ""
 USE_LORA = True if LLM_LORA_PATH else False
 
@@ -45,8 +44,6 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_
 
 UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
 
-API_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_content")
-
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
 PROMPT_TEMPLATE = """已知信息:
 {context} 
@@ -62,4 +59,4 @@ LLM_HISTORY_LEN = 3
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 5
 
-NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "nltk_data")
+NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

+ 2 - 2
models/chatglm_llm.py

@@ -144,12 +144,12 @@ class ChatGLM(LLM):
                         config=model_config, **kwargs)
                 if LLM_LORA_PATH and use_lora:
                     from peft import PeftModel
-                    model_auto = PeftModel.from_pretrained(model, LLM_LORA_PATH)
+                    model = PeftModel.from_pretrained(model, LLM_LORA_PATH)
                 # 可传入device_map自定义每张卡的部署情况
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
 
-                self.model = dispatch_model(model_auto.half(), device_map=device_map)
+                self.model = dispatch_model(model.half(), device_map=device_map)
         else:
             self.model = self.model.float().to(llm_device)
 

+ 4 - 8
webui.py

@@ -48,12 +48,6 @@ def get_answer(query, vs_path, history, mode,
             yield history, ""
 
 
-def update_status(history, status):
-    history = history + [[None, status]]
-    print(status)
-    return history
-
-
 def init_model():
     try:
         local_doc_qa.init_cfg()
@@ -92,10 +86,12 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, us
 def get_vector_store(vs_id, files, history):
     vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     filelist = []
+    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
+        os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
     for file in files:
         filename = os.path.split(file.name)[-1]
-        shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, filename))
-        filelist.append(os.path.join(UPLOAD_ROOT_PATH, filename))
+        shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
+        filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
     if local_doc_qa.llm and local_doc_qa.embeddings:
         vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
         if len(loaded_files):