
add streaming option in configs/model_config.py

imClumsyPanda 2 years ago
parent
commit
0e8cc0d16c
3 changed files with 30 additions and 12 deletions
  1. chains/local_doc_qa.py (+4, -4)
  2. configs/model_config.py (+3, -0)
  3. webui.py (+23, -8)

+ 4 - 4
chains/local_doc_qa.py

@@ -116,10 +116,12 @@ class LocalDocQA:
                  llm_history_len: int = LLM_HISTORY_LEN,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
+                 streaming=STREAMING,
                  top_k=VECTOR_SEARCH_TOP_K,
                  use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
+        self.llm.streaming = streaming
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                             llm_device=llm_device,
                             use_ptuning_v2=use_ptuning_v2)
@@ -186,9 +188,7 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[],
-                                   streaming=True):
-        self.llm.streaming = streaming
+                                   chat_history=[]):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
         vector_store.chunk_size=self.chunk_size
@@ -197,7 +197,7 @@ class LocalDocQA:
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
 
-        if streaming:
+        if self.llm.streaming:
             for result, history in self.llm._call(prompt=prompt,
                                                   history=chat_history):
                 history[-1][0] = query

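With this change the streaming decision is made once, at configuration time: get_knowledge_based_answer no longer accepts a streaming argument and instead branches on self.llm.streaming. A minimal sketch of the two calling conventions this implies, mirroring how webui.py consumes the method in this commit (the query string and vector-store path are illustrative placeholders, and local_doc_qa is assumed to have been configured via init_cfg as in the sketch after the configs/model_config.py hunk):

    if local_doc_qa.llm.streaming:
        # streaming mode: iterate the method; each step yields an updated
        # (resp, history) pair as the partial answer grows
        for resp, history in local_doc_qa.get_knowledge_based_answer(
                query="如何部署本项目?", vs_path="vector_store/demo", chat_history=[]):
            print(history[-1][-1])   # answer accumulated so far
    else:
        # non-streaming mode: unpack a single (resp, history) pair; the
        # retrieved chunks are available in resp["source_documents"]
        resp, history = local_doc_qa.get_knowledge_based_answer(
            query="如何部署本项目?", vs_path="vector_store/demo", chat_history=[])
        print(history[-1][-1])
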
+ 3 - 0
configs/model_config.py

@@ -27,6 +27,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM streaming response
+STREAMING = True
+
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
 
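STREAMING is a module-level switch: webui.py (below) forwards it into init_cfg for both the initial load and reloads, so the default behaviour is changed with a one-line edit to configs/model_config.py. A short sketch of that wiring, assuming package-style imports (the exact import statements in webui.py are outside this diff):

    from configs.model_config import STREAMING   # True after this commit
    from chains.local_doc_qa import LocalDocQA

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(streaming=STREAMING)   # stored as local_doc_qa.llm.streaming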

+ 23 - 8
webui.py

@@ -30,8 +30,8 @@ local_doc_qa = LocalDocQA()
 
 
 def get_answer(query, vs_path, history, mode):
-    if mode == "知识库问答":
-        if vs_path:
+    if mode == "知识库问答" and vs_path:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.get_knowledge_based_answer(
                     query=query, vs_path=vs_path, chat_history=history):
                 source = "\n\n"
@@ -44,14 +44,28 @@ def get_answer(query, vs_path, history, mode):
                 history[-1][-1] += source
                 yield history, ""
         else:
+            resp, history = local_doc_qa.get_knowledge_based_answer(
+                query=query, vs_path=vs_path, chat_history=history)
+            source = "\n\n"
+            source += "".join(
+                [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                 f"""{doc.page_content}\n"""
+                 f"""</details>"""
+                 for i, doc in
+                 enumerate(resp["source_documents"])])
+            history[-1][-1] += source
+            return history, ""
+    else:
+        if local_doc_qa.llm.streaming:
             for resp, history in local_doc_qa.llm._call(query, history):
                 history[-1][-1] = resp + (
                     "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
                 yield history, ""
-    else:
-        for resp, history in local_doc_qa.llm._call(query, history):
-            history[-1][-1] = resp
-            yield history, ""
+        else:
+            resp, history = local_doc_qa.llm._call(query, history)
+            history[-1][-1] = resp + (
+                "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+            return history, ""
 
 
 def update_status(history, status):
@@ -62,7 +76,7 @@ def update_status(history, status):
 
 def init_model():
     try:
-        local_doc_qa.init_cfg()
+        local_doc_qa.init_cfg(streaming=STREAMING)
         local_doc_qa.llm._call("你好")
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(reply)
@@ -84,7 +98,8 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k)
+                              top_k=top_k,
+                              streaming=STREAMING)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e:
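
For the plain-chat path with no knowledge base loaded, get_answer now applies the same split to the LLM itself: it iterates llm._call when streaming and unpacks a single (resp, history) pair otherwise. A standalone sketch of that calling convention outside Gradio, again mirroring webui.py (the query string is illustrative):

    query, history = "你好", []

    if local_doc_qa.llm.streaming:
        # streaming: _call yields (resp, history) pairs incrementally
        for resp, history in local_doc_qa.llm._call(query, history):
            print(resp)
    else:
        # non-streaming: _call returns one (resp, history) pair
        resp, history = local_doc_qa.llm._call(query, history)
        print(resp)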