Explorar el Código

update webui.py

imClumsyPanda hace 2 años
padre
commit
0a4dd1987d
Se han modificado 1 ficheros con 32 adiciones y 23 borrados
  1. 32 23
      webui.py

+ 32 - 23
webui.py

@@ -8,17 +8,19 @@ import uuid
 
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
+
 def get_vs_list():
 def get_vs_list():
-    lst_default  = ["新建知识库"]
+    lst_default = ["新建知识库"]
     if not os.path.exists(VS_ROOT_PATH):
     if not os.path.exists(VS_ROOT_PATH):
         return lst_default
         return lst_default
-    lst= os.listdir(VS_ROOT_PATH)
-    if not lst:           
+    lst = os.listdir(VS_ROOT_PATH)
+    if not lst:
         return lst_default
         return lst_default
     lst.sort(reverse=True)
     lst.sort(reverse=True)
-    return lst+ lst_default
+    return lst + lst_default
+
 
 
-vs_list =get_vs_list()
+vs_list = get_vs_list()
 
 
 embedding_model_dict_list = list(embedding_model_dict.keys())
 embedding_model_dict_list = list(embedding_model_dict.keys())
 
 
@@ -29,6 +31,7 @@ local_doc_qa = LocalDocQA()
 logger = gr.CSVLogger()
 logger = gr.CSVLogger()
 username = uuid.uuid4().hex
 username = uuid.uuid4().hex
 
 
+
 def get_answer(query, vs_path, history, mode,
 def get_answer(query, vs_path, history, mode,
                streaming: bool = STREAMING):
                streaming: bool = STREAMING):
     if mode == "知识库问答" and vs_path:
     if mode == "知识库问答" and vs_path:
@@ -51,8 +54,9 @@ def get_answer(query, vs_path, history, mode,
                                                     streaming=streaming):
                                                     streaming=streaming):
             history[-1][-1] = resp + (
             history[-1][-1] = resp + (
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-            yield history, "" 
-    logger.flag([query, vs_path, history, mode],username=username)
+            yield history, ""
+    logger.flag([query, vs_path, history, mode], username=username)
+
 
 
 def init_model():
 def init_model():
     try:
     try:
@@ -78,8 +82,8 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, us
                               embedding_model=embedding_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
                               use_ptuning_v2=use_ptuning_v2,
-                              use_lora = use_lora,
-                              top_k=top_k,)
+                              use_lora=use_lora,
+                              top_k=top_k, )
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
         print(model_status)
     except Exception as e:
     except Exception as e:
@@ -111,12 +115,14 @@ def get_vector_store(vs_id, files, history):
     return vs_path, None, history + [[None, file_status]]
     return vs_path, None, history + [[None, file_status]]
 
 
 
 
def change_vs_name_input(vs_id, history):
    """Gradio callback fired when the knowledge-base dropdown selection changes.

    Args:
        vs_id: selected dropdown value; the sentinel "新建知识库" means the
            user wants to create a new knowledge base.
        history: current chatbot transcript (list of [user, bot] pairs).

    Returns:
        A 5-tuple of (vs_name update, vs_add update, file2vs update,
        vs_path-or-None, updated history) wired to the component outputs.
    """
    if vs_id != "新建知识库":
        # An existing store was picked: hide the name/add widgets, show the
        # uploader, resolve its on-disk path, and log a status line.
        file_status = f"已加载知识库{vs_id},请开始提问"
        return (gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True),
                os.path.join(VS_ROOT_PATH, vs_id),
                history + [[None, file_status]])
    # "Create new" sentinel: expose the naming widgets, no path yet,
    # transcript unchanged.
    return (gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=False),
            None,
            history)
 
 
 
 
 def change_mode(mode):
 def change_mode(mode):
@@ -136,6 +142,7 @@ def add_vs_name(vs_name, vs_list, chatbot):
         chatbot = chatbot + [[None, vs_status]]
         chatbot = chatbot + [[None, vs_status]]
         return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
         return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
 
 
+
 block_css = """.importantButton {
 block_css = """.importantButton {
     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
     border: none !important;
     border: none !important;
@@ -163,10 +170,11 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
 """
 """
 
 
 model_status = init_model()
 model_status = init_model()
-default_path =  os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""
+default_path = os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""
 
 
 with gr.Blocks(css=block_css) as demo:
 with gr.Blocks(css=block_css) as demo:
-    vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(model_status), gr.State(vs_list)
+    vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(
+        model_status), gr.State(vs_list)
     gr.Markdown(webui_title)
     gr.Markdown(webui_title)
     with gr.Tab("对话"):
     with gr.Tab("对话"):
         with gr.Row():
         with gr.Row():
@@ -175,7 +183,7 @@ with gr.Blocks(css=block_css) as demo:
                                      elem_id="chat-box",
                                      elem_id="chat-box",
                                      show_label=False).style(height=750)
                                      show_label=False).style(height=750)
                 query = gr.Textbox(show_label=False,
                 query = gr.Textbox(show_label=False,
-                                   placeholder="请输入提问内容,按回车进行提交").style(container=False) 
+                                   placeholder="请输入提问内容,按回车进行提交").style(container=False)
             with gr.Column(scale=5):
             with gr.Column(scale=5):
                 mode = gr.Radio(["LLM 对话", "知识库问答"],
                 mode = gr.Radio(["LLM 对话", "知识库问答"],
                                 label="请选择使用模式",
                                 label="请选择使用模式",
@@ -218,7 +226,7 @@ with gr.Blocks(css=block_css) as demo:
                             load_folder_button = gr.Button("上传文件夹并加载知识库")
                             load_folder_button = gr.Button("上传文件夹并加载知识库")
                     # load_vs.click(fn=)
                     # load_vs.click(fn=)
                     select_vs.change(fn=change_vs_name_input,
                     select_vs.change(fn=change_vs_name_input,
-                                     inputs=[select_vs,chatbot],
+                                     inputs=[select_vs, chatbot],
                                      outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
                                      outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
                     # 将上传的文件保存到content文件夹下,并更新下拉框
                     # 将上传的文件保存到content文件夹下,并更新下拉框
                     load_file_button.click(get_vector_store,
                     load_file_button.click(get_vector_store,
@@ -230,11 +238,11 @@ with gr.Blocks(css=block_css) as demo:
                                              show_progress=True,
                                              show_progress=True,
                                              inputs=[select_vs, folder_files, chatbot],
                                              inputs=[select_vs, folder_files, chatbot],
                                              outputs=[vs_path, folder_files, chatbot],
                                              outputs=[vs_path, folder_files, chatbot],
-                                             )             
-                    logger.setup([query, vs_path, chatbot, mode], "flagged")       
+                                             )
+                    logger.setup([query, vs_path, chatbot, mode], "flagged")
                     query.submit(get_answer,
                     query.submit(get_answer,
-                               [query, vs_path, chatbot, mode],
-                               [chatbot, query])                  
+                                 [query, vs_path, chatbot, mode],
+                                 [chatbot, query])
     with gr.Tab("模型配置"):
     with gr.Tab("模型配置"):
         llm_model = gr.Radio(llm_model_dict_list,
         llm_model = gr.Radio(llm_model_dict_list,
                              label="LLM 模型",
                              label="LLM 模型",
@@ -250,8 +258,8 @@ with gr.Blocks(css=block_css) as demo:
                                      label="使用p-tuning-v2微调过的模型",
                                      label="使用p-tuning-v2微调过的模型",
                                      interactive=True)
                                      interactive=True)
         use_lora = gr.Checkbox(USE_LORA,
         use_lora = gr.Checkbox(USE_LORA,
-                                     label="使用lora微调的权重",
-                                     interactive=True)
+                               label="使用lora微调的权重",
+                               interactive=True)
         embedding_model = gr.Radio(embedding_model_dict_list,
         embedding_model = gr.Radio(embedding_model_dict_list,
                                    label="Embedding 模型",
                                    label="Embedding 模型",
                                    value=EMBEDDING_MODEL,
                                    value=EMBEDDING_MODEL,
@@ -265,7 +273,8 @@ with gr.Blocks(css=block_css) as demo:
         load_model_button = gr.Button("重新加载模型")
         load_model_button = gr.Button("重新加载模型")
     load_model_button.click(reinit_model,
     load_model_button.click(reinit_model,
                             show_progress=True,
                             show_progress=True,
-                            inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
+                            inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k,
+                                    chatbot],
                             outputs=chatbot
                             outputs=chatbot
                             )
                             )