123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- from configs.model_config import *
- from chains.local_doc_qa import LocalDocQA
- import os
- import nltk
- import uvicorn
- from fastapi import FastAPI, File, UploadFile
- from pydantic import BaseModel
- from starlette.responses import RedirectResponse
app = FastAPI()

# Shared state used by the request handlers. Neither name is bound here:
# ``local_doc_qa`` is created by the startup hook, and ``vs_path`` is set by
# the POST /file handler after a document has been indexed. Until a file has
# been uploaded, POST /qa reports an error because ``vs_path`` is unbound.

# Search the nltk_data directory that sits next to this file before nltk's
# default data locations.
nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path

# NOTE(review): these constants may shadow same-named settings pulled in by
# ``from configs.model_config import *`` -- confirm which values should win.

# Number of top-k text chunks returned from the vector store.
VECTOR_SEARCH_TOP_K = 10
# Length of the chat history fed to the LLM.
LLM_HISTORY_LEN = 3
# If True, /qa returns the full answer object (including source text from the
# input document); otherwise only its "result" field.
REPLY_WITH_SOURCE = False
class Query(BaseModel):
    """JSON request body for POST /qa, carrying the user's question."""

    query: str
@app.get('/')
async def document():
    """Redirect requests for the site root to the interactive API docs."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
@app.on_event("startup")
async def get_local_doc_qa():
    """Create and configure the shared LocalDocQA instance at server startup.

    The instance is stored in the module global ``local_doc_qa`` so that the
    request handlers can use it.
    """
    global local_doc_qa
    local_doc_qa = LocalDocQA()
    settings = dict(
        llm_model=LLM_MODEL,
        embedding_model=EMBEDDING_MODEL,
        embedding_device=EMBEDDING_DEVICE,
        llm_history_len=LLM_HISTORY_LEN,
        top_k=VECTOR_SEARCH_TOP_K,
    )
    local_doc_qa.init_cfg(**settings)
-
@app.post("/file")
async def upload_file(UserFile: UploadFile=File(...)):
    """Save an uploaded document under ./content and index it.

    The file is written to ``./content/<basename>`` and passed to
    ``local_doc_qa.init_knowledge_vector_store``. The returned vector-store
    path is stored in the module global ``vs_path`` for use by POST /qa.

    Returns:
        On success, a dict with ``msg`` ('success' or 'fail'), ``status``
        (1 if any file was loaded, else 0) and ``loaded_files``. If an
        exception occurs, ``msg`` is None, ``status`` is 0 and ``message``
        holds the error text.
    """
    global vs_path
    response = {
        "msg": None,
        "status": 0
    }
    try:
        # The filename comes from the client. Keep only its final component
        # (handling both separator styles) so names like "../../x" cannot
        # write outside the upload directory.
        raw_name = (UserFile.filename or "").replace("\\", "/")
        filename = os.path.basename(raw_name)
        if filename in ("", ".", ".."):
            raise ValueError("uploaded file has no usable filename")

        upload_dir = './content'
        os.makedirs(upload_dir, exist_ok=True)
        filepath = os.path.join(upload_dir, filename)

        content = await UserFile.read()
        with open(filepath, 'wb') as f:
            f.write(content)

        vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
        loaded = len(files) > 0
        response = {
            'msg': 'success' if loaded else 'fail',
            'status': 1 if loaded else 0,
            'loaded_files': files
        }
    except Exception as err:
        # Return the error text; an exception object cannot be serialized
        # into a meaningful JSON value.
        response["message"] = str(err)

    return response
@app.post("/qa")
async def get_answer(UserQuery: Query):
    """Answer a question using the most recently indexed document.

    Uses the vector store path recorded in the module global ``vs_path`` by
    POST /file. Each request starts from an empty chat history, and the
    history returned by the model is discarded, so nothing carries over
    between requests.

    Returns:
        A dict with ``status`` (1 on success, 0 on failure), ``message``
        ('successful' or the error text) and ``answer``. ``answer`` is the
        full answer object when REPLY_WITH_SOURCE is set, otherwise only its
        "result" field, and None on failure.
    """
    response = {
        "status": 0,
        "message": "",
        "answer": None
    }
    history = []
    try:
        # Reading ``vs_path`` before any upload raises NameError, which is
        # reported through the except clause below.
        resp, history = local_doc_qa.get_knowledge_based_answer(query=UserQuery.query,
                                                                vs_path=vs_path,
                                                                chat_history=history)
        if REPLY_WITH_SOURCE:
            response["answer"] = resp
        else:
            response['answer'] = resp["result"]

        response["message"] = 'successful'
        response["status"] = 1
    except Exception as err:
        # Return the error text; an exception object cannot be serialized
        # into a meaningful JSON value.
        response["message"] = str(err)

    return response
if __name__ == "__main__":
    # With reload enabled, uvicorn needs the app as a "module:attribute"
    # import string rather than the object itself. Derive the module name
    # from this file instead of hard-coding it, so the launcher still works
    # if the file is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        app=f"{module_name}:app",
        host='0.0.0.0',
        port=8100,
        reload=True,
    )
|