Browse Source

update webui.py

imClumsyPanda 2 years ago
parent
commit
635aa2d2ac
2 changed files with 77 additions and 74 deletions
  1. 0 9
      chains/local_doc_qa.py
  2. 77 65
      webui.py

+ 0 - 9
chains/local_doc_qa.py

@@ -17,10 +17,6 @@ VECTOR_SEARCH_TOP_K = 6
 # LLM input history length
 LLM_HISTORY_LEN = 3
 
-<<<<<<< HEAD
-<<<<<<< HEAD
-=======
->>>>>>> 7cc03c3 (feat: add api for knowledge_based QA)
 
 def load_file(filepath):
     if filepath.lower().endswith(".pdf"):
@@ -33,11 +29,6 @@ def load_file(filepath):
         docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
-<<<<<<< HEAD
-=======
->>>>>>> cba44ca (修复 webui.py 中 llm_history_len 和 vector_search_top_k 显示值与启动设置默认值不一致的问题)
-=======
->>>>>>> 7cc03c3 (feat: add api for knowledge_based QA)
 
 class LocalDocQA:
     llm: object = None

+ 77 - 65
webui.py

@@ -12,18 +12,7 @@ VECTOR_SEARCH_TOP_K = 6
 
 # LLM input history length
 LLM_HISTORY_LEN = 3
-<<<<<<< HEAD
-=======
 
-<<<<<<< HEAD
->>>>>>> f87a5f5 (fix bug in webui.py)
-=======
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
->>>>>>> cba44ca (修复 webui.py 中 llm_history_len 和 vector_search_top_k 显示值与启动设置默认值不一致的问题)
 
 def get_file_list():
     if not os.path.exists("content"):
@@ -31,7 +20,14 @@ def get_file_list():
     return [f for f in os.listdir("content")]
 
 
+def get_vs_list():
+    if not os.path.exists("vector_store"):
+        return []
+    return [f for f in os.listdir("vector_store")]
+
+
 file_list = get_file_list()
+vs_list = get_vs_list()
 
 embedding_model_dict_list = list(embedding_model_dict.keys())
 
@@ -40,22 +36,30 @@ llm_model_dict_list = list(llm_model_dict.keys())
 local_doc_qa = LocalDocQA()
 
 
-def upload_file(file):
+def upload_file(file, chatbot):
     if not os.path.exists("content"):
         os.mkdir("content")
     filename = os.path.basename(file.name)
     shutil.move(file.name, "content/" + filename)
     # file_list首位插入新上传的文件
     file_list.insert(0, filename)
-    return gr.Dropdown.update(choices=file_list, value=filename)
+    status = "已将xx上传至xxx"
+    return chatbot + [None, status]
 
 
 def get_answer(query, vs_path, history):
     if vs_path:
         resp, history = local_doc_qa.get_knowledge_based_answer(
             query=query, vs_path=vs_path, chat_history=history)
+        source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
+{doc.page_content}
+
+<b>所属文件:</b>{doc.metadata["source"]}
+</details>""" for i, doc in enumerate(resp["source_documents"])])
+        history[-1][-1] += source
     else:
-        history = history + [[None, "请先加载文件后,再进行提问。"]]
+        resp = local_doc_qa.llm._call(query)
+        history = history + [[None, resp + "\n如需基于知识库进行问答,请先加载知识库后,再进行提问。"]]
     return history, ""
 
 
@@ -68,6 +72,7 @@ def update_status(history, status):
 def init_model():
     try:
         local_doc_qa.init_cfg()
+        local_doc_qa.llm._call("你好")
         return """模型已成功加载,请选择文件后点击"加载文件"按钮"""
     except Exception as e:
         print(e)
@@ -88,7 +93,6 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
     return history + [[None, model_status]]
 
 
-
 def get_vector_store(filepath, history):
     if local_doc_qa.llm and local_doc_qa.embeddings:
         vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
@@ -120,71 +124,79 @@ webui_title = """
 """
 
 init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
-1. 选择语言模型、Embedding 模型及相关参数,如果使用ptuning-v2方式微调过模型,将PrefixEncoder模型放在ptuning-v2文件夹里并勾选相关选项,然后点击"重新加载模型",并等待加载完成提示
+1. 选择语言模型、Embedding 模型及相关参数,如果使用 ptuning-v2 方式微调过模型,将 PrefixEncoder 模型放在 ptuning-v2 文件夹里并勾选相关选项,然后点击"重新加载模型",并等待加载完成提示
 2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
 3. 输入要提交的问题后,点击回车提交 """
 
-
 model_status = init_model()
 
 with gr.Blocks(css=block_css) as demo:
     vs_path, file_status, model_status = gr.State(""), gr.State(""), gr.State(model_status)
     gr.Markdown(webui_title)
-    with gr.Row():
-        with gr.Column(scale=2):
-            chatbot = gr.Chatbot([[None, init_message], [None, model_status.value]],
-                                 elem_id="chat-box",
-                                 show_label=False).style(height=750)
-            query = gr.Textbox(show_label=False,
-                               placeholder="请输入提问内容,按回车进行提交",
-                               ).style(container=False)
-
-        with gr.Column(scale=1):
-            llm_model = gr.Radio(llm_model_dict_list,
-                                 label="LLM 模型",
-                                 value=LLM_MODEL,
-                                 interactive=True)
-            llm_history_len = gr.Slider(0,
-                                        10,
-                                        value=LLM_HISTORY_LEN,
-                                        step=1,
-                                        label="LLM history len",
-                                        interactive=True)
-            use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
-                                         label="使用p-tuning-v2微调过的模型",
-                                         interactive=True)
-            embedding_model = gr.Radio(embedding_model_dict_list,
-                                       label="Embedding 模型",
-                                       value=EMBEDDING_MODEL,
-                                       interactive=True)
-            top_k = gr.Slider(1,
-                              20,
-                              value=VECTOR_SEARCH_TOP_K,
-                              step=1,
-                              label="向量匹配 top k",
-                              interactive=True)
-            load_model_button = gr.Button("重新加载模型")
-
-            # with gr.Column():
-            with gr.Tab("select"):
-                selectFile = gr.Dropdown(file_list,
-                                         label="content file",
+    with gr.Tab("聊天"):
+        with gr.Row():
+            with gr.Column(scale=2):
+                chatbot = gr.Chatbot([[None, init_message], [None, model_status.value]],
+                                     elem_id="chat-box",
+                                     show_label=False).style(height=750)
+                query = gr.Textbox(show_label=False,
+                                   placeholder="请输入提问内容,按回车进行提交",
+                                   ).style(container=False)
+
+            with gr.Column(scale=1):
+                # with gr.Column():
+                # with gr.Tab("select"):
+                selectFile = gr.Dropdown(vs_list,
+                                         label="请选择要加载的知识库",
                                          interactive=True,
-                                         value=file_list[0] if len(file_list) > 0 else None)
-            with gr.Tab("upload"):
-                file = gr.File(label="content file",
-                               file_types=['.txt', '.md', '.docx', '.pdf']
-                               )  # .style(height=100)
-            load_file_button = gr.Button("加载文件")
+                                         value=vs_list[0] if len(vs_list) > 0 else None)
+                #
+                gr.Markdown("向知识库中添加文件")
+                with gr.Tab("上传文件"):
+                    files = gr.File(label="向知识库中添加文件",
+                                    file_types=['.txt', '.md', '.docx', '.pdf'],
+                                    file_count="multiple"
+                                    )  # .style(height=100)
+                with gr.Tab("上传文件夹"):
+                    files = gr.File(label="向知识库中添加文件",
+                                    file_types=['.txt', '.md', '.docx', '.pdf'],
+                                    file_count="directory"
+                                    )  # .style(height=100)
+                load_file_button = gr.Button("加载知识库")
+    with gr.Tab("模型配置"):
+        llm_model = gr.Radio(llm_model_dict_list,
+                             label="LLM 模型",
+                             value=LLM_MODEL,
+                             interactive=True)
+        llm_history_len = gr.Slider(0,
+                                    10,
+                                    value=LLM_HISTORY_LEN,
+                                    step=1,
+                                    label="LLM history len",
+                                    interactive=True)
+        use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
+                                     label="使用p-tuning-v2微调过的模型",
+                                     interactive=True)
+        embedding_model = gr.Radio(embedding_model_dict_list,
+                                   label="Embedding 模型",
+                                   value=EMBEDDING_MODEL,
+                                   interactive=True)
+        top_k = gr.Slider(1,
+                          20,
+                          value=VECTOR_SEARCH_TOP_K,
+                          step=1,
+                          label="向量匹配 top k",
+                          interactive=True)
+        load_model_button = gr.Button("重新加载模型")
     load_model_button.click(reinit_model,
                             show_progress=True,
                             inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                             outputs=chatbot
                             )
     # 将上传的文件保存到content文件夹下,并更新下拉框
-    file.upload(upload_file,
-                inputs=file,
-                outputs=selectFile)
+    files.upload(upload_file,
+                 inputs=[files, chatbot],
+                 outputs=chatbot)
     load_file_button.click(get_vector_store,
                            show_progress=True,
                            inputs=[selectFile, chatbot],