Ver Fonte

处理引用错误

tenggangren há 2 anos atrás
pai
commit
5ba558ed9c
1 ficheiros alterados com 21 adições e 21 exclusões
  1. 21 21
      api/api.py

+ 21 - 21
api/api.py

@@ -1,5 +1,6 @@
 import sys
-sys.path.append("..") # 将父目录放入系统路径中
+
+sys.path.append("..")  # 将父目录放入系统路径中
 
 from fastapi import FastAPI, Request, UploadFile, File
 from fastapi.responses import StreamingResponse
@@ -118,6 +119,24 @@ async def create_item(request: Request):
     # temperature = json_post_list.get('temperature')
     chatglm.history = history
     chatglm.is_stream_chat = 0
+    vector_store = init_vector_store(vs_path)
+    system_template = """基于以下内容,简洁和专业的来回答用户的问题。
+        如果无法从中得到答案,请说 "不知道" 或 "没有足够的相关信息",不要试图编造答案,答案只要中文。
+        ----------------
+        {context}
+        ----------------
+        """
+    messages = [
+        SystemMessagePromptTemplate.from_template(system_template),
+        HumanMessagePromptTemplate.from_template("{question}"),
+    ]
+    prompt = ChatPromptTemplate.from_messages(messages)
+    knowledge_chain = RetrievalQA.from_llm(
+        llm=chatglm,
+        retriever=vector_store.as_retriever(search_kwargs={"k": model_config.VECTOR_SEARCH_TOP_K}),
+        prompt=prompt
+    )
+    knowledge_chain.return_source_documents = False
     response = knowledge_chain({"query": query})
     # chatglm.history[-1][0] = query
     end_time = time.perf_counter()
@@ -166,9 +185,8 @@ def init_vector_store(vs_path):
     print("init_vector_store===操作耗时: {:.6f} 秒".format(elapsed_time))
     return vector_store
 
-
 def init_cfg():
-    global chatglm, embeddings, model_init, knowledge_chain, vector_store
+    global chatglm, embeddings, vector_store
     print("预加载模型......")
     start_time = time.perf_counter()
     print("加载GLM模型......")
@@ -182,24 +200,6 @@ def init_cfg():
     # 输出耗时时间
     print("模型预加载耗时: {:.6f} 秒".format(elapsed_time))
 
-    system_template = """基于以下内容,简洁和专业的来回答用户的问题。
-   如果无法从中得到答案,请说 "不知道" 或 "没有足够的相关信息",不要试图编造答案,答案只要中文。
-   ----------------
-   {context}
-   ----------------
-   """
-    messages = [
-        SystemMessagePromptTemplate.from_template(system_template),
-        HumanMessagePromptTemplate.from_template("{question}"),
-    ]
-    prompt = ChatPromptTemplate.from_messages(messages)
-    knowledge_chain = RetrievalQA.from_llm(
-        llm=chatglm,
-        retriever=vector_store.as_retriever(search_kwargs={"k": model_config.VECTOR_SEARCH_TOP_K}),
-        prompt=prompt
-    )
-    knowledge_chain.return_source_documents = False
-
 
 if __name__ == '__main__':
     init_embedding()