|
@@ -1,103 +0,0 @@
|
|
-from configs.model_config import *
|
|
|
|
-from chains.local_doc_qa import LocalDocQA
|
|
|
|
-import os
|
|
|
|
-import nltk
|
|
|
|
-
|
|
|
|
-import uvicorn
|
|
|
|
-from fastapi import FastAPI, File, UploadFile
|
|
|
|
-from pydantic import BaseModel
|
|
|
|
-from starlette.responses import RedirectResponse
|
|
|
|
-
|
|
|
|
-app = FastAPI()
|
|
|
|
-
|
|
|
|
-global local_doc_qa, vs_path
|
|
|
|
-
|
|
|
|
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
|
|
|
|
-
|
|
|
|
-# return top-k text chunk from vector store
|
|
|
|
-VECTOR_SEARCH_TOP_K = 10
|
|
|
|
-
|
|
|
|
-# LLM input history length
|
|
|
|
-LLM_HISTORY_LEN = 3
|
|
|
|
-
|
|
|
|
-# Show reply with source text from input document
|
|
|
|
-REPLY_WITH_SOURCE = False
|
|
|
|
-
|
|
|
|
-class Query(BaseModel):
|
|
|
|
- query: str
|
|
|
|
-
|
|
|
|
-@app.get('/')
|
|
|
|
-async def document():
|
|
|
|
- return RedirectResponse(url="/docs")
|
|
|
|
-
|
|
|
|
-@app.on_event("startup")
|
|
|
|
-async def get_local_doc_qa():
|
|
|
|
- global local_doc_qa
|
|
|
|
- local_doc_qa = LocalDocQA()
|
|
|
|
- local_doc_qa.init_cfg(llm_model=LLM_MODEL,
|
|
|
|
- embedding_model=EMBEDDING_MODEL,
|
|
|
|
- embedding_device=EMBEDDING_DEVICE,
|
|
|
|
- llm_history_len=LLM_HISTORY_LEN,
|
|
|
|
- top_k=VECTOR_SEARCH_TOP_K)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-@app.post("/file")
|
|
|
|
-async def upload_file(UserFile: UploadFile=File(...),):
|
|
|
|
- global vs_path
|
|
|
|
- response = {
|
|
|
|
- "msg": None,
|
|
|
|
- "status": 0
|
|
|
|
- }
|
|
|
|
- try:
|
|
|
|
- filepath = './content/' + UserFile.filename
|
|
|
|
- content = await UserFile.read()
|
|
|
|
- # print(UserFile.filename)
|
|
|
|
- with open(filepath, 'wb') as f:
|
|
|
|
- f.write(content)
|
|
|
|
- vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
|
|
|
|
- response = {
|
|
|
|
- 'msg': 'seccess' if len(files)>0 else 'fail',
|
|
|
|
- 'status': 1 if len(files)>0 else 0,
|
|
|
|
- 'loaded_files': files
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- except Exception as err:
|
|
|
|
- response["message"] = err
|
|
|
|
-
|
|
|
|
- return response
|
|
|
|
-
|
|
|
|
-@app.post("/qa")
|
|
|
|
-async def get_answer(query: str = ""):
|
|
|
|
- response = {
|
|
|
|
- "status": 0,
|
|
|
|
- "message": "",
|
|
|
|
- "answer": None
|
|
|
|
- }
|
|
|
|
- global vs_path
|
|
|
|
- history = []
|
|
|
|
- try:
|
|
|
|
- resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
|
|
|
|
- vs_path=vs_path,
|
|
|
|
- chat_history=history)
|
|
|
|
- if REPLY_WITH_SOURCE:
|
|
|
|
- response["answer"] = resp
|
|
|
|
- else:
|
|
|
|
- response['answer'] = resp["result"]
|
|
|
|
-
|
|
|
|
- response["message"] = 'successful'
|
|
|
|
- response["status"] = 1
|
|
|
|
-
|
|
|
|
- except Exception as err:
|
|
|
|
- response["message"] = err
|
|
|
|
-
|
|
|
|
- return response
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-if __name__ == "__main__":
|
|
|
|
- uvicorn.run(
|
|
|
|
- app=app,
|
|
|
|
- host='0.0.0.0',
|
|
|
|
- port=8100,
|
|
|
|
- reload=True,
|
|
|
|
- )
|
|
|
|
-
|
|
|