api.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from configs.model_config import *
  2. from chains.local_doc_qa import LocalDocQA
  3. import os
  4. import nltk
  5. import uvicorn
  6. from fastapi import FastAPI, File, UploadFile
  7. from pydantic import BaseModel
  8. from starlette.responses import RedirectResponse
  9. app = FastAPI()
  10. global local_doc_qa, vs_path
  11. nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
  12. # return top-k text chunk from vector store
  13. VECTOR_SEARCH_TOP_K = 10
  14. # LLM input history length
  15. LLM_HISTORY_LEN = 3
  16. # Show reply with source text from input document
  17. REPLY_WITH_SOURCE = False
  18. class Query(BaseModel):
  19. query: str
  20. @app.get('/')
  21. async def document():
  22. return RedirectResponse(url="/docs")
  23. @app.on_event("startup")
  24. async def get_local_doc_qa():
  25. global local_doc_qa
  26. local_doc_qa = LocalDocQA()
  27. local_doc_qa.init_cfg(llm_model=LLM_MODEL,
  28. embedding_model=EMBEDDING_MODEL,
  29. embedding_device=EMBEDDING_DEVICE,
  30. llm_history_len=LLM_HISTORY_LEN,
  31. top_k=VECTOR_SEARCH_TOP_K)
  32. @app.post("/file")
  33. async def upload_file(UserFile: UploadFile=File(...)):
  34. global vs_path
  35. response = {
  36. "msg": None,
  37. "status": 0
  38. }
  39. try:
  40. filepath = './content/' + UserFile.filename
  41. content = await UserFile.read()
  42. # print(UserFile.filename)
  43. with open(filepath, 'wb') as f:
  44. f.write(content)
  45. vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
  46. response = {
  47. 'msg': 'seccess' if len(files)>0 else 'fail',
  48. 'status': 1 if len(files)>0 else 0,
  49. 'loaded_files': files
  50. }
  51. except Exception as err:
  52. response["message"] = err
  53. return response
  54. @app.post("/qa")
  55. async def get_answer(UserQuery: Query):
  56. response = {
  57. "status": 0,
  58. "message": "",
  59. "answer": None
  60. }
  61. global vs_path
  62. history = []
  63. try:
  64. resp, history = local_doc_qa.get_knowledge_based_answer(query=UserQuery.query,
  65. vs_path=vs_path,
  66. chat_history=history)
  67. if REPLY_WITH_SOURCE:
  68. response["answer"] = resp
  69. else:
  70. response['answer'] = resp["result"]
  71. response["message"] = 'successful'
  72. response["status"] = 1
  73. except Exception as err:
  74. response["message"] = err
  75. return response
  76. if __name__ == "__main__":
  77. uvicorn.run(
  78. app='api:app',
  79. host='0.0.0.0',
  80. port=8100,
  81. reload = True,
  82. )