Ver Fonte

update webui.py

imClumsyPanda há 2 anos atrás
pai
commit
e9a5db1a9d
1 ficheiros alterados com 63 adições e 69 exclusões
  1. 63 69
      webui.py

+ 63 - 69
webui.py

@@ -14,19 +14,12 @@ VECTOR_SEARCH_TOP_K = 6
 LLM_HISTORY_LEN = 3
 
 
-def get_file_list():
-    if not os.path.exists("content"):
-        return []
-    return [f for f in os.listdir("content")]
-
-
 def get_vs_list():
-    if not os.path.exists("vector_store"):
+    if not os.path.exists(VS_ROOT_PATH):
         return []
-    return ["新建知识库"] + os.listdir("vector_store")
+    return ["新建知识库"] + os.listdir(VS_ROOT_PATH)
 
 
-file_list = get_file_list()
 vs_list = get_vs_list()
 
 embedding_model_dict_list = list(embedding_model_dict.keys())
@@ -36,19 +29,8 @@ llm_model_dict_list = list(llm_model_dict.keys())
 local_doc_qa = LocalDocQA()
 
 
-def upload_file(file, chatbot):
-    if not os.path.exists("content"):
-        os.mkdir("content")
-    filename = os.path.basename(file.name)
-    shutil.move(file.name, "content/" + filename)
-    # file_list首位插入新上传的文件
-    file_list.insert(0, filename)
-    status = "已将xx上传至xxx"
-    return chatbot + [None, status]
-
-
-def get_answer(query, vs_path, history):
-    if vs_path:
+def get_answer(query, vs_path, history, mode):
+    if vs_path and mode == "知识库问答":
         resp, history = local_doc_qa.get_knowledge_based_answer(
             query=query, vs_path=vs_path, chat_history=history)
         source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
@@ -93,24 +75,30 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
     return history + [[None, model_status]]
 
 
-def get_vector_store(filepath, history):
+def get_vector_store(vs_id, files, history):
+    vs_path = VS_ROOT_PATH + vs_id
+    filelist = []
+    for file in files:
+        filename = os.path.split(file.name)[-1]
+        shutil.move(file.name, UPLOAD_ROOT_PATH + filename)
+        filelist.append(UPLOAD_ROOT_PATH + filename)
     if local_doc_qa.llm and local_doc_qa.embeddings:
-        vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
-        if vs_path:
-            file_status = "文件已成功加载,请开始提问"
+        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
+        if len(loaded_files):
+            file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
         else:
             file_status = "文件未成功加载,请重新上传文件"
     else:
         file_status = "模型未完成加载,请先在加载模型后再导入文件"
         vs_path = None
-    return vs_path, history + [[None, file_status]]
+    return vs_path, None, history + [[None, file_status]]
 
 
 def change_vs_name_input(vs):
     if vs == "新建知识库":
-        return gr.update(visible=True), gr.update(visible=True)
+        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
     else:
-        return gr.update(visible=False), gr.update(visible=False)
+        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
 
 
 def change_mode(mode):
@@ -119,15 +107,15 @@ def change_mode(mode):
     else:
         return gr.update(visible=False)
 
+
 def add_vs_name(vs_name, vs_list, chatbot):
     if vs_name in vs_list:
-        chatbot = chatbot+[None, "与已有知识库名称冲突,请重新选择其他名称后提交"]
-        return gr.update(visible=True),vs_list, chatbot
+        chatbot = chatbot + [[None, "与已有知识库名称冲突,请重新选择其他名称后提交"]]
+        return gr.update(visible=True), vs_list, chatbot
     else:
-        chatbot = chatbot + [None, f"""已新增知识库"{vs_name}" """]
-        vs_list = vs_list+[vs_name]
-        return gr.update(visible=True),vs_list, chatbot
-
+        chatbot = chatbot + [
+            [None, f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """]]
+        return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
 
 
 block_css = """.importantButton {
@@ -155,7 +143,7 @@ init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请
 model_status = init_model()
 
 with gr.Blocks(css=block_css) as demo:
-    vs_path, file_status, model_status = gr.State(""), gr.State(""), gr.State(model_status)
+    vs_path, file_status, model_status, vs_list = gr.State(""), gr.State(""), gr.State(model_status), gr.State(vs_list)
     gr.Markdown(webui_title)
     with gr.Tab("对话"):
         with gr.Row():
@@ -169,16 +157,17 @@ with gr.Blocks(css=block_css) as demo:
             with gr.Column(scale=5):
                 mode = gr.Radio(["LLM 对话", "知识库问答"],
                                 label="请选择使用模式",
-                                value="知识库问答",)
+                                value="知识库问答", )
                 vs_setting = gr.Accordion("配置知识库")
                 mode.change(fn=change_mode,
                             inputs=mode,
                             outputs=vs_setting)
                 with vs_setting:
-                    select_vs = gr.Dropdown(vs_list,
+                    select_vs = gr.Dropdown(vs_list.value,
                                             label="请选择要加载的知识库",
                                             interactive=True,
-                                            value=vs_list[0] if len(vs_list) > 0 else None)
+                                            value=vs_list.value[0] if len(vs_list.value) > 0 else None
+                                            )
                     vs_name = gr.Textbox(label="请输入新建知识库名称",
                                          lines=1,
                                          interactive=True)
@@ -186,24 +175,42 @@ with gr.Blocks(css=block_css) as demo:
                     vs_add.click(fn=add_vs_name,
                                  inputs=[vs_name, vs_list, chatbot],
                                  outputs=[select_vs, vs_list, chatbot])
+
+                    file2vs = gr.Box(visible=False)
+                    with file2vs:
+                        gr.Markdown("向知识库中添加文件")
+                        with gr.Tab("上传文件"):
+                            files = gr.File(label="添加文件",
+                                            file_types=['.txt', '.md', '.docx', '.pdf'],
+                                            file_count="multiple",
+                                            show_label=False
+                                            )
+                            load_file_button = gr.Button("上传文件")
+                        with gr.Tab("上传文件夹"):
+                            folder_files = gr.File(label="添加文件",
+                                                   # file_types=['.txt', '.md', '.docx', '.pdf'],
+                                                   file_count="directory",
+                                                   show_label=False
+                                                   )
+                            load_folder_button = gr.Button("上传文件夹")
                     select_vs.change(fn=change_vs_name_input,
                                      inputs=select_vs,
-                                     outputs=[vs_name, vs_add])
-                    gr.Markdown("向知识库中添加文件")
-                    with gr.Tab("上传文件"):
-                        files = gr.File(label="添加文件",
-                                        file_types=['.txt', '.md', '.docx', '.pdf'],
-                                        file_count="multiple",
-                                        show_label=False
-                                        )
-                        load_file_button = gr.Button("上传文件")
-                    with gr.Tab("上传文件夹"):
-                        folder_files = gr.File(label="添加文件",
-                                               file_types=['.txt', '.md', '.docx', '.pdf'],
-                                               file_count="directory",
-                                               show_label=False
-                                               )
-                        load_folder_button = gr.Button("上传文件夹")
+                                     outputs=[vs_name, vs_add, file2vs])
+                    # 将上传的文件保存到content文件夹下,并更新下拉框
+                    load_file_button.click(get_vector_store,
+                                           show_progress=True,
+                                           inputs=[select_vs, files, chatbot],
+                                           outputs=[vs_path, files, chatbot],
+                                           )
+                    load_folder_button.click(get_vector_store,
+                                             show_progress=True,
+                                             inputs=[select_vs, folder_files, chatbot],
+                                             outputs=[vs_path, folder_files, chatbot],
+                                             )
+                    query.submit(get_answer,
+                                 [query, vs_path, chatbot, mode],
+                                 [chatbot, query],
+                                 )
     with gr.Tab("模型配置"):
         llm_model = gr.Radio(llm_model_dict_list,
                              label="LLM 模型",
@@ -213,7 +220,7 @@ with gr.Blocks(css=block_css) as demo:
                                     10,
                                     value=LLM_HISTORY_LEN,
                                     step=1,
-                                    label="LLM history len",
+                                    label="LLM 对话轮数",
                                     interactive=True)
         use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                      label="使用p-tuning-v2微调过的模型",
@@ -234,19 +241,6 @@ with gr.Blocks(css=block_css) as demo:
                             inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                             outputs=chatbot
                             )
-    # 将上传的文件保存到content文件夹下,并更新下拉框
-    files.upload(upload_file,
-                 inputs=[files, chatbot],
-                 outputs=chatbot)
-    load_file_button.click(get_vector_store,
-                           show_progress=True,
-                           inputs=[select_vs, chatbot],
-                           outputs=[vs_path, chatbot],
-                           )
-    query.submit(get_answer,
-                 [query, vs_path, chatbot],
-                 [chatbot, query],
-                 )
 
 demo.queue(concurrency_count=3
            ).launch(server_name='0.0.0.0',