api.py

import sys
sys.path.append("..")  # add the parent directory to the module search path
import os
# Must be set before torch gets imported, otherwise multi-GPU setups misbehave
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
from fastapi import FastAPI, Request, UploadFile, File
from fastapi.responses import StreamingResponse
import uvicorn, json, datetime, time
from langchain.vectorstores import FAISS
from starlette.middleware.cors import CORSMiddleware
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import RetrievalQA
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.document_loaders import UnstructuredFileLoader
# pip install pinecone-client (switch to a faster mirror if needed)
import pinecone
import sentence_transformers
from configs import *
from models import *

app = FastAPI()

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
# POST endpoint with streaming output
@app.post("/stream")
async def create_stream_item(request: Request):
    json_post_raw = await request.json()
    query = json_post_raw.get('prompt')
    history = json_post_raw.get('history')
    vs_path = json_post_raw.get('vs_path')
    print("chat_history========================================", history)
    print("Starting query========================================")
    # max_length = json_post_raw.get('max_length')
    # top_p = json_post_raw.get('top_p')
    # temperature = json_post_raw.get('temperature')
    now = datetime.datetime.now()
    time_stamp = now.strftime("%Y-%m-%d %H:%M:%S")
    # Validate vs_path before running the chain
    if vs_path is None or vs_path == "":
        answer = {
            "response": "vs_path must not be empty",
            "status": 200,
            "time": time_stamp
        }
        return StreamingResponse(iter([json.dumps(answer, ensure_ascii=False) + "\n"]),
                                 status_code=200, media_type="application/json")
    chatglm.history = history
    chatglm.chat_mode = ModelType.stream_chat
    knowledge_chain({"query": query})
    return StreamingResponse(chatglm.start_stream_chat(query, vs_path),
                             status_code=200, media_type="application/json")
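
# A minimal client sketch for the /stream endpoint (assumes the server runs on
# 127.0.0.1:8899 as configured below and that the `requests` package is installed;
# the vs_path value shown is hypothetical):
#
#   import requests
#   payload = {"prompt": "your question here", "history": [],
#              "vs_path": "./vector_store/example_FAISS_20240101_000000"}
#   with requests.post("http://127.0.0.1:8899/stream", json=payload, stream=True) as resp:
#       for chunk in resp.iter_lines(decode_unicode=True):
#           if chunk:
#               print(chunk)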
@app.post("/uploadfile")
async def create_upload_file(file: UploadFile = File(...)):
    # Save the uploaded file to the configured local path
    file_path = os.path.join(model_config.UPLOAD_LOCAL_PATH, file.filename)
    with open(file_path, "wb") as f:
        # Read the uploaded content and write it to disk
        f.write(file.file.read())
    now = datetime.datetime.now()
    time_stamp = now.strftime("%Y-%m-%d %H:%M:%S")
    try:
        loader = UnstructuredFileLoader(file_path, mode="elements")
        docs = loader.load()
        print(f"{file.filename} loaded successfully")
    except Exception:
        print(f"{file.filename} failed to load")
        return {"response": f"{file.filename} failed to load",
                "status": 500,
                "time": time_stamp}
    # Build a FAISS index from the parsed documents and persist it locally
    vector_store = FAISS.from_documents(docs, embeddings)
    vs_path = f"""./vector_store/{os.path.splitext(file.filename)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
    vector_store.save_local(vs_path)
    response = {"filename": file.filename, "local_vs_path": vs_path}
    answer = {
        "response": response,
        "status": 200,
        "time": time_stamp
    }
    return answer
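
# A minimal client sketch for /uploadfile (server address as configured below;
# the local file name is hypothetical):
#
#   import requests
#   with open("sample.txt", "rb") as fh:
#       resp = requests.post("http://127.0.0.1:8899/uploadfile", files={"file": fh})
#   vs_path = resp.json()["response"]["local_vs_path"]  # reuse for / and /stream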
# POST endpoint with a blocking (non-streaming) response
@app.post("/")
async def create_item(request: Request):
    start_time = time.perf_counter()
    json_post_raw = await request.json()
    query = json_post_raw.get('prompt')
    history = json_post_raw.get('history')
    vs_path = json_post_raw.get('vs_path')
    print("chat_history========================================", history)
    print("Starting query========================================")
    # max_length = json_post_raw.get('max_length')
    # top_p = json_post_raw.get('top_p')
    # temperature = json_post_raw.get('temperature')
    chatglm.history = history
    chatglm.is_stream_chat = 0
    response = knowledge_chain({"query": query})
    # chatglm.history[-1][0] = query
    end_time = time.perf_counter()
    # Report how long the question answering took
    elapsed_time = end_time - start_time
    print("QA request took: {:.6f} s".format(elapsed_time))
    now = datetime.datetime.now()
    time_stamp = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": [],
        "status": 200,
        "time": time_stamp
    }
    log = f"[{time_stamp}] response: {repr(response)}"
    print(log)
    chatglm.torch_gc()
    print("answer=====>", answer)
    return answer
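
# A minimal client sketch for the blocking endpoint (vs_path comes from a prior
# /uploadfile call; the value shown is hypothetical):
#
#   import requests
#   payload = {"prompt": "your question here", "history": [],
#              "vs_path": "./vector_store/example_FAISS_20240101_000000"}
#   answer = requests.post("http://127.0.0.1:8899/", json=payload).json()
#   print(answer["response"])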
def init_embedding():
    global embeddings
    print("Loading embedding model......")
    embeddings = HuggingFaceEmbeddings(model_name=model_config.embedding_model_dict[model_config.EMBEDDING_MODEL])
    embeddings.client = sentence_transformers.SentenceTransformer(embeddings.model_name,
                                                                  device=model_config.EMBEDDING_DEVICE)
    print("Embedding model loaded......")
def init_vector_store(vs_path):
    start_time = time.perf_counter()
    if model_config.IS_LOCAL_STORAGE:
        vector_store = FAISS.load_local(vs_path, embeddings)
    else:
        # Register for a free Pinecone account to obtain api_key, environment and index_name
        pinecone.init(api_key="", environment="")
        index_name = ""
        vector_store = Pinecone.from_existing_index(index_name=index_name, embedding=embeddings)
    end_time = time.perf_counter()
    # Report how long loading the vector store took
    elapsed_time = end_time - start_time
    print("init_vector_store took: {:.6f} s".format(elapsed_time))
    return vector_store
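
# A quick sanity-check sketch for a loaded store (run after init_embedding();
# the vs_path is hypothetical, and k mirrors VECTOR_SEARCH_TOP_K used by the chain below):
#
#   store = init_vector_store("./vector_store/example_FAISS_20240101_000000")
#   for doc in store.similarity_search("your question here", k=3):
#       print(doc.page_content)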
def init_cfg():
    global chatglm, embeddings, model_init, knowledge_chain, vector_store
    print("Preloading models......")
    start_time = time.perf_counter()
    print("Loading the GLM model......")
    chatglm = ChatGLM()
    chatglm.load_model(model_name_or_path=model_config.llm_model_dict[model_config.LLM_MODEL])
    chatglm.history_len = model_config.LLM_HISTORY_LEN
    print("Model loaded!!!")
    end_time = time.perf_counter()
    # Report how long model preloading took
    elapsed_time = end_time - start_time
    print("Model preloading took: {:.6f} s".format(elapsed_time))
    # System prompt (kept in Chinese for the Chinese LLM): answer concisely and
    # professionally from the given context; say "不知道" (don't know) or
    # "没有足够的相关信息" (not enough relevant information) if the answer is not in the
    # context; do not make up answers; answer in Chinese only.
    system_template = """基于以下内容,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "不知道" 或 "没有足够的相关信息",不要试图编造答案,答案只要中文。
----------------
{context}
----------------
"""
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    knowledge_chain = RetrievalQA.from_llm(
        llm=chatglm,
        retriever=vector_store.as_retriever(search_kwargs={"k": model_config.VECTOR_SEARCH_TOP_K}),
        prompt=prompt
    )
    knowledge_chain.return_source_documents = False
if __name__ == '__main__':
    init_embedding()
    # NOTE: init_cfg() builds the retriever from the global vector_store, so a saved
    # FAISS index has to be loaded first; the path below is a placeholder for an index
    # previously created via /uploadfile.
    vector_store = init_vector_store("./vector_store/your_saved_index")
    init_cfg()
    # For external access, bind 0.0.0.0 and open the port in the security group / firewall
    # uvicorn.run(app, host='0.0.0.0', port=8899, log_level="info")
    uvicorn.run(app, host='127.0.0.1', port=8899, log_level="info")