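"""API server for the local document QA service.

Exposes endpoints for uploading documents into a knowledge base, listing
and deleting them, and asking questions against the resulting vector store,
either request/response (POST /chat-docs/chat) or streaming over a WebSocket
(/chat-docs/stream-chat/{knowledge_base_id}).
"""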
import argparse
import json
import os
import shutil
import subprocess
import tempfile
from typing import List, Optional

import nltk
import pydantic
import uvicorn
from fastapi import (Body, FastAPI, File, Form, Query, UploadFile, WebSocket,
                     WebSocketDisconnect)
from fastapi.openapi.utils import get_openapi
from pydantic import BaseModel
from typing_extensions import Annotated

from chains.local_doc_qa import LocalDocQA
from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
                                  EMBEDDING_MODEL, LLM_MODEL)

# prefer the bundled nltk_data directory over any system-wide installation
nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path

# number of text chunks returned from the vector store per query
VECTOR_SEARCH_TOP_K = 6
# number of previous dialogue turns fed back to the LLM
LLM_HISTORY_LEN = 3
class BaseResponse(BaseModel):
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field("success", description="HTTP status message")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
            }
        }


class ListDocsResponse(BaseResponse):
    data: List[str] = pydantic.Field(..., description="List of document names")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
                "data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
            }
        }
class ChatMessage(BaseModel):
    question: str = pydantic.Field(..., description="Question text")
    response: str = pydantic.Field(..., description="Response text")
    history: List[List[str]] = pydantic.Field(..., description="History text")
    source_documents: List[str] = pydantic.Field(
        ..., description="List of source documents and their scores"
    )

    class Config:
        schema_extra = {
            "example": {
                "question": "工伤保险如何办理?",
                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
                "history": [
                    [
                        "工伤保险是什么?",
                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                    ]
                ],
                "source_documents": [
                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
                    "出处 [2] ...",
                    "出处 [3] ...",
                ],
            }
        }
def get_folder_path(local_doc_id: str):
    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id)


def get_vs_path(local_doc_id: str):
    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, "vector_store")


def get_file_path(local_doc_id: str, doc_name: str):
    return os.path.join(API_UPLOAD_ROOT_PATH, local_doc_id, doc_name)
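# Resulting on-disk layout (illustrative, for a knowledge base id "kb1"):
#   {API_UPLOAD_ROOT_PATH}/kb1/                -> get_folder_path("kb1")
#   {API_UPLOAD_ROOT_PATH}/kb1/doc1.pdf        -> get_file_path("kb1", "doc1.pdf")
#   {API_UPLOAD_ROOT_PATH}/kb1/vector_store/   -> get_vs_path("kb1")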
async def upload_file(
    files: Annotated[
        List[UploadFile], File(description="Multiple files as UploadFile")
    ],
    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
    saved_path = get_folder_path(knowledge_base_id)
    if not os.path.exists(saved_path):
        os.makedirs(saved_path)
    for file in files:
        file_path = os.path.join(saved_path, file.filename)
        with open(file_path, "wb") as f:
            f.write(file.file.read())
    # rebuild the vector store so the new documents become searchable
    local_doc_qa.init_knowledge_vector_store(saved_path, get_vs_path(knowledge_base_id))
    return BaseResponse()
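# Illustrative request (assumes the defaults from main(), i.e. the route is
# mounted at POST /chat-docs/upload on port 7861):
#   curl -X POST http://localhost:7861/chat-docs/upload \
#        -F "knowledge_base_id=kb1" -F "files=@doc1.pdf" -F "files=@doc2.txt"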
async def list_docs(
    # default to None so the query parameter is truly optional
    knowledge_base_id: Optional[str] = Query(None, description="Knowledge Base Name", example="kb1")
):
    if knowledge_base_id:
        local_doc_folder = get_folder_path(knowledge_base_id)
        if not os.path.exists(local_doc_folder):
            # return the declared response model rather than a bare dict; a
            # dict without "data" would fail FastAPI's response validation
            return ListDocsResponse(
                code=1, msg=f"Knowledge base {knowledge_base_id} not found", data=[]
            )
        all_doc_names = [
            doc
            for doc in os.listdir(local_doc_folder)
            if os.path.isfile(os.path.join(local_doc_folder, doc))
        ]
        return ListDocsResponse(data=all_doc_names)
    else:
        if not os.path.exists(API_UPLOAD_ROOT_PATH):
            all_doc_ids = []
        else:
            all_doc_ids = [
                folder
                for folder in os.listdir(API_UPLOAD_ROOT_PATH)
                if os.path.isdir(os.path.join(API_UPLOAD_ROOT_PATH, folder))
            ]
        # without a knowledge_base_id, list the knowledge base ids themselves
        return ListDocsResponse(data=all_doc_ids)
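# Illustrative calls:
#   GET /chat-docs/list?knowledge_base_id=kb1
#     -> {"code": 200, "msg": "success", "data": ["doc1.pdf", ...]}
#   GET /chat-docs/list
#     -> {"code": 200, "msg": "success", "data": ["kb1", "kb2", ...]}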
async def delete_docs(
    knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
    doc_name: Optional[str] = Form(
        None, description="doc name", example="doc_name_1.pdf"
    ),
):
    if not os.path.exists(os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id)):
        return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
    if doc_name:
        doc_path = get_file_path(knowledge_base_id, doc_name)
        if os.path.exists(doc_path):
            os.remove(doc_path)
        else:
            return BaseResponse(code=1, msg=f"document {doc_name} not found")
        # list_docs returns a pydantic model, so use attribute access
        # (subscripting it would raise a TypeError); success is code 200
        remain_docs = await list_docs(knowledge_base_id)
        if remain_docs.code != 200 or len(remain_docs.data) == 0:
            # last document removed: drop the whole knowledge base folder
            shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
        else:
            # otherwise rebuild the vector store from the remaining documents
            local_doc_qa.init_knowledge_vector_store(
                get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
            )
    else:
        # no doc_name given: delete the entire knowledge base
        shutil.rmtree(get_folder_path(knowledge_base_id))
    return BaseResponse()
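# Illustrative call (form-encoded, matching the Form() parameters above):
#   curl -X DELETE http://localhost:7861/chat-docs/delete \
#        -F "knowledge_base_id=kb1" -F "doc_name=doc_name_1.pdf"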
async def chat(
    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
    question: str = Body(..., description="Question", example="工伤保险是什么?"),
    history: List[List[str]] = Body(
        [],
        description="History of previous questions and answers",
        example=[
            [
                "工伤保险是什么?",
                "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
            ]
        ],
    ),
):
    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
    if not os.path.exists(vs_path):
        raise ValueError(f"Knowledge base {knowledge_base_id} not found")
    # drain the streaming generator; after the loop, resp and history hold
    # the final answer and the updated chat history
    for resp, history in local_doc_qa.get_knowledge_based_answer(
        query=question, vs_path=vs_path, chat_history=history, streaming=True
    ):
        pass
    source_documents = [
        f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
        f"""相关度:{doc.metadata['score']}\n\n"""
        for inum, doc in enumerate(resp["source_documents"])
    ]
    return ChatMessage(
        question=question,
        response=resp["result"],
        history=history,
        source_documents=source_documents,
    )
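# Illustrative request body for POST /chat-docs/chat:
#   {"knowledge_base_id": "kb1", "question": "工伤保险是什么?", "history": []}
# The response is a ChatMessage as shown in the schema example above.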
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
    await websocket.accept()
    vs_path = os.path.join(API_UPLOAD_ROOT_PATH, knowledge_base_id, "vector_store")
    if not os.path.exists(vs_path):
        await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
        await websocket.close()
        return
    history = []
    turn = 1
    while True:
        try:
            question = await websocket.receive_text()
        except WebSocketDisconnect:
            # the client hung up; end the session instead of erroring out
            break
        await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
        last_print_len = 0
        # forward only the newly generated suffix of the answer on each step
        for resp, history in local_doc_qa.get_knowledge_based_answer(
            query=question, vs_path=vs_path, chat_history=history, streaming=True
        ):
            await websocket.send_text(resp["result"][last_print_len:])
            last_print_len = len(resp["result"])
        source_documents = [
            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
            f"""相关度:{doc.metadata['score']}\n\n"""
            for inum, doc in enumerate(resp["source_documents"])
        ]
        await websocket.send_text(
            json.dumps(
                {
                    "question": question,
                    "turn": turn,
                    "flag": "end",
                    "source_documents": source_documents,
                },
                ensure_ascii=False,
            )
        )
        turn += 1
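# Illustrative client flow (a sketch, assuming the third-party `websockets`
# package and the default port from main()):
#   async with websockets.connect(
#       "ws://localhost:7861/chat-docs/stream-chat/kb1"
#   ) as ws:
#       await ws.send("工伤保险是什么?")
#       # first frame: JSON {"flag": "start", ...}; then incremental answer
#       # text frames; finally JSON {"flag": "end", "source_documents": [...]}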
def gen_docs():
    global app
    with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=".json") as f:
        json.dump(
            get_openapi(
                title=app.title,
                version=app.version,
                openapi_version=app.openapi_version,
                description=app.description,
                routes=app.routes,
            ),
            f,
            ensure_ascii=False,
        )
        f.flush()
        # convert the OpenAPI spec to markdown with widdershins; a missing
        # binary or a failed run surfaces as an exception below
        try:
            subprocess.run(
                [
                    "widdershins",
                    f.name,
                    "-o",
                    os.path.join(
                        os.path.dirname(os.path.abspath(__file__)),
                        "docs",
                        "API.md",
                    ),
                ],
                check=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to generate docs. Please install widdershins first."
            ) from e
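# widdershins is an npm tool; if it is missing, `npm install -g widdershins`
# is the usual way to put it on the PATH.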
def main():
    global app
    global local_doc_qa
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--gen-docs", action="store_true")
    args = parser.parse_args()

    app = FastAPI()
    # register all routes before gen_docs() so the exported OpenAPI spec
    # covers every endpoint
    app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
    app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
    app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
    app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
    app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
    if args.gen_docs:
        gen_docs()
        return

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(
        llm_model=LLM_MODEL,
        embedding_model=EMBEDDING_MODEL,
        embedding_device=EMBEDDING_DEVICE,
        llm_history_len=LLM_HISTORY_LEN,
        top_k=VECTOR_SEARCH_TOP_K,
    )
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
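# Usage (illustrative; assumes this module is saved as api.py):
#   python api.py --host 0.0.0.0 --port 7861   # start the API server
#   python api.py --gen-docs                   # regenerate docs/API.md and exit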