|
@@ -13,11 +13,10 @@ from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
|
from fastapi.openapi.utils import get_openapi
|
|
from fastapi.openapi.utils import get_openapi
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
from typing_extensions import Annotated
|
|
from typing_extensions import Annotated
|
|
-
|
|
|
|
|
|
+from starlette.responses import RedirectResponse
|
|
from chains.local_doc_qa import LocalDocQA
|
|
from chains.local_doc_qa import LocalDocQA
|
|
-from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
|
|
|
- EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
|
|
|
|
- VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
|
|
|
|
|
|
+from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH,
|
|
|
|
+ NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
|
|
|
|
|
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|
|
|
|
|
@@ -76,37 +75,47 @@ class ChatMessage(BaseModel):
|
|
|
|
|
|
|
|
|
|
def get_folder_path(local_doc_id: str):
|
|
def get_folder_path(local_doc_id: str):
|
|
- return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)
|
|
|
|
|
|
+ return os.path.join(UPLOAD_ROOT_PATH, local_doc_id)
|
|
|
|
|
|
|
|
|
|
def get_vs_path(local_doc_id: str):
|
|
def get_vs_path(local_doc_id: str):
|
|
- return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")
|
|
|
|
|
|
+ return os.path.join(VS_ROOT_PATH, local_doc_id)
|
|
|
|
|
|
|
|
|
|
def get_file_path(local_doc_id: str, doc_name: str):
|
|
def get_file_path(local_doc_id: str, doc_name: str):
|
|
- return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
|
|
|
|
|
+ return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
|
|
|
|
|
|
|
|
|
async def upload_file(
|
|
async def upload_file(
|
|
- files: Annotated[
|
|
|
|
- List[UploadFile], File(description="Multiple files as UploadFile")
|
|
|
|
- ],
|
|
|
|
- knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
|
|
|
|
|
+ files: Annotated[
|
|
|
|
+ List[UploadFile], File(description="Multiple files as UploadFile")
|
|
|
|
+ ],
|
|
|
|
+ knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
|
):
|
|
):
|
|
saved_path = get_folder_path(knowledge_base_id)
|
|
saved_path = get_folder_path(knowledge_base_id)
|
|
if not os.path.exists(saved_path):
|
|
if not os.path.exists(saved_path):
|
|
os.makedirs(saved_path)
|
|
os.makedirs(saved_path)
|
|
|
|
+ filelist = []
|
|
for file in files:
|
|
for file in files:
|
|
|
|
+ file_content = ''
|
|
file_path = os.path.join(saved_path, file.filename)
|
|
file_path = os.path.join(saved_path, file.filename)
|
|
- with open(file_path, "wb") as f:
|
|
|
|
- f.write(file.file.read())
|
|
|
|
-
|
|
|
|
- local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(knowledge_base_id))
|
|
|
|
- return BaseResponse()
|
|
|
|
|
|
+ file_content = file.file.read()
|
|
|
|
+ if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
|
|
|
+ continue
|
|
|
|
+ with open(file_path, "ab+") as f:
|
|
|
|
+ f.write(file_content)
|
|
|
|
+ filelist.append(file_path)
|
|
|
|
+ if filelist:
|
|
|
|
+ vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
|
|
|
|
+ if len(loaded_files):
|
|
|
|
+ file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
|
|
|
|
+ return BaseResponse(code=200, msg=file_status)
|
|
|
|
+ file_status = "文件未成功加载,请重新上传文件"
|
|
|
|
+ return BaseResponse(code=500, msg=file_status)
|
|
|
|
|
|
|
|
|
|
async def list_docs(
|
|
async def list_docs(
|
|
- knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
|
|
|
|
|
|
+ knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1")
|
|
):
|
|
):
|
|
if knowledge_base_id:
|
|
if knowledge_base_id:
|
|
local_doc_folder = get_folder_path(knowledge_base_id)
|
|
local_doc_folder = get_folder_path(knowledge_base_id)
|
|
@@ -119,25 +128,27 @@ async def list_docs(
|
|
]
|
|
]
|
|
return ListDocsResponse(data=all_doc_names)
|
|
return ListDocsResponse(data=all_doc_names)
|
|
else:
|
|
else:
|
|
- if not os.path.exists(API_UPLOAD_ROOT_PATH):
|
|
|
|
|
|
+ if not os.path.exists(UPLOAD_ROOT_PATH):
|
|
all_doc_ids = []
|
|
all_doc_ids = []
|
|
else:
|
|
else:
|
|
all_doc_ids = [
|
|
all_doc_ids = [
|
|
folder
|
|
folder
|
|
- for folder in os.listdir(API_UPLOAD_ROOT_PATH)
|
|
|
|
- if os.path.isdir(os.path.join(API_UPLOAD_ROOT_PATH, folder))
|
|
|
|
|
|
+ for folder in os.listdir(UPLOAD_ROOT_PATH)
|
|
|
|
+ if os.path.isdir(os.path.join(UPLOAD_ROOT_PATH, folder))
|
|
]
|
|
]
|
|
|
|
|
|
return ListDocsResponse(data=all_doc_ids)
|
|
return ListDocsResponse(data=all_doc_ids)
|
|
|
|
|
|
|
|
|
|
async def delete_docs(
|
|
async def delete_docs(
|
|
- knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
|
|
|
- doc_name: Optional[str] = Form(
|
|
|
|
- None, description="doc name", example="doc_name_1.pdf"
|
|
|
|
- ),
|
|
|
|
|
|
+ knowledge_base_id: str = Form(...,
|
|
|
|
+ description="Knowledge Base Name(注意此方法仅删除上传的文件并不会删除知识库(FAISS)内数据)",
|
|
|
|
+ example="kb1"),
|
|
|
|
+ doc_name: Optional[str] = Form(
|
|
|
|
+ None, description="doc name", example="doc_name_1.pdf"
|
|
|
|
+ ),
|
|
):
|
|
):
|
|
- if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id)):
|
|
|
|
|
|
+ if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id)):
|
|
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
|
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
|
if doc_name:
|
|
if doc_name:
|
|
doc_path = get_file_path(knowledge_base_id, doc_name)
|
|
doc_path = get_file_path(knowledge_base_id, doc_name)
|
|
@@ -159,25 +170,25 @@ async def delete_docs(
|
|
|
|
|
|
|
|
|
|
async def chat(
|
|
async def chat(
|
|
- knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
|
|
|
- question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
|
|
|
- history: List[List[str]] = Body(
|
|
|
|
- [],
|
|
|
|
- description="History of previous questions and answers",
|
|
|
|
- example=[
|
|
|
|
- [
|
|
|
|
- "工伤保险是什么?",
|
|
|
|
- "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
|
|
|
- ]
|
|
|
|
- ],
|
|
|
|
- ),
|
|
|
|
|
|
+ knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
|
|
|
+ question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
|
|
|
+ history: List[List[str]] = Body(
|
|
|
|
+ [],
|
|
|
|
+ description="History of previous questions and answers",
|
|
|
|
+ example=[
|
|
|
|
+ [
|
|
|
|
+ "工伤保险是什么?",
|
|
|
|
+ "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
|
|
|
+ ]
|
|
|
|
+ ],
|
|
|
|
+ ),
|
|
):
|
|
):
|
|
- vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
|
|
|
|
|
|
+ vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
|
if not os.path.exists(vs_path):
|
|
if not os.path.exists(vs_path):
|
|
raise ValueError(f"Knowledge base {knowledge_base_id} not found")
|
|
raise ValueError(f"Knowledge base {knowledge_base_id} not found")
|
|
|
|
|
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
|
- query=question, vs_path=vs_path, chat_history=history, streaming=True
|
|
|
|
|
|
+ query=question, vs_path=vs_path, chat_history=history, streaming=True
|
|
):
|
|
):
|
|
pass
|
|
pass
|
|
source_documents = [
|
|
source_documents = [
|
|
@@ -196,7 +207,7 @@ async def chat(
|
|
|
|
|
|
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
|
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
|
await websocket.accept()
|
|
await websocket.accept()
|
|
- vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
|
|
|
|
|
|
+ vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
|
|
|
|
|
if not os.path.exists(vs_path):
|
|
if not os.path.exists(vs_path):
|
|
await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
|
|
await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
|
|
@@ -211,7 +222,7 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
|
|
|
|
|
last_print_len = 0
|
|
last_print_len = 0
|
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
|
- query=question, vs_path=vs_path, chat_history=history, streaming=True
|
|
|
|
|
|
+ query=question, vs_path=vs_path, chat_history=history, streaming=True
|
|
):
|
|
):
|
|
await websocket.send_text(resp["result"][last_print_len:])
|
|
await websocket.send_text(resp["result"][last_print_len:])
|
|
last_print_len = len(resp["result"])
|
|
last_print_len = len(resp["result"])
|
|
@@ -236,40 +247,8 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
|
turn += 1
|
|
turn += 1
|
|
|
|
|
|
|
|
|
|
-def gen_docs():
|
|
|
|
- global app
|
|
|
|
- with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=".json") as f:
|
|
|
|
- json.dump(
|
|
|
|
- get_openapi(
|
|
|
|
- title=app.title,
|
|
|
|
- version=app.version,
|
|
|
|
- openapi_version=app.openapi_version,
|
|
|
|
- description=app.description,
|
|
|
|
- routes=app.routes,
|
|
|
|
- ),
|
|
|
|
- f,
|
|
|
|
- ensure_ascii=False,
|
|
|
|
- )
|
|
|
|
- f.flush()
|
|
|
|
- # test whether widdershins is available
|
|
|
|
- try:
|
|
|
|
- subprocess.run(
|
|
|
|
- [
|
|
|
|
- "widdershins",
|
|
|
|
- f.name,
|
|
|
|
- "-o",
|
|
|
|
- os.path.join(
|
|
|
|
- os.path.dirname(os.path.abspath(__file__)),
|
|
|
|
- "docs",
|
|
|
|
- "API.md",
|
|
|
|
- ),
|
|
|
|
- ],
|
|
|
|
- check=True,
|
|
|
|
- )
|
|
|
|
- except Exception:
|
|
|
|
- raise RuntimeError(
|
|
|
|
- "Failed to generate docs. Please install widdershins first."
|
|
|
|
- )
|
|
|
|
|
|
+async def document():
|
|
|
|
+ return RedirectResponse(url="/docs")
|
|
|
|
|
|
|
|
|
|
def main():
|
|
def main():
|
|
@@ -278,7 +257,6 @@ def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
|
parser.add_argument("--port", type=int, default=7861)
|
|
parser.add_argument("--port", type=int, default=7861)
|
|
- parser.add_argument("--gen-docs", action="store_true")
|
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
app = FastAPI()
|
|
app = FastAPI()
|
|
@@ -287,10 +265,7 @@ def main():
|
|
app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
|
|
app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
|
|
app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
|
|
app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
|
|
app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
|
|
app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
|
|
-
|
|
|
|
- if args.gen_docs:
|
|
|
|
- gen_docs()
|
|
|
|
- return
|
|
|
|
|
|
+ app.get("/", response_model=BaseResponse)(document)
|
|
|
|
|
|
local_doc_qa = LocalDocQA()
|
|
local_doc_qa = LocalDocQA()
|
|
local_doc_qa.init_cfg(
|
|
local_doc_qa.init_cfg(
|