Bläddra i källkod

update webui.py

imClumsyPanda 2 år sedan
förälder
incheckning
5c3497912c
1 ändrade filer med 55 tillägg och 51 borttagningar
  1. 55 51
      webui.py

+ 55 - 51
webui.py

@@ -36,12 +36,10 @@ def get_answer(query, vs_path, history):
     return history, history
     return history, history
 
 
 
 
-def get_model_status(history):
-    return history + [[None, "模型已完成加载,请选择要加载的文档"]]
-
-
-def get_file_status(history):
-    return history + [[None, "文档已完成加载,请开始提问"]]
+def update_status(history, status):
+    history = history + [[None, status]]
+    print(status)
+    return history
 
 
 
 
 def init_model():
 def init_model():
@@ -53,22 +51,28 @@ def init_model():
 
 
 
 
 def reinit_model(llm_model, embedding_model, llm_history_len, top_k):
 def reinit_model(llm_model, embedding_model, llm_history_len, top_k):
-    local_doc_qa.init_cfg(llm_model=llm_model,
-                          embedding_model=embedding_model,
-                          llm_history_len=llm_history_len,
-                          top_k=top_k),
+    try:
+        local_doc_qa.init_cfg(llm_model=llm_model,
+                              embedding_model=embedding_model,
+                              llm_history_len=llm_history_len,
+                              top_k=top_k)
+        return """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
+    except:
+        return """模型未成功重新加载,请重新选择后点击"加载模型"按钮"""
+
 
 
 
 
 def get_vector_store(filepath):
 def get_vector_store(filepath):
-    local_doc_qa.init_knowledge_vector_store("content/"+filepath)
+    vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
+    if vs_path:
+        file_status = "文件已成功加载,请开始提问"
+    else:
+        file_status = "文件未成功加载,请重新上传文件"
+    print(file_status)
+    return vs_path, file_status
 
 
 
 
-model_status = gr.State()
-history = gr.State([])
-vs_path = gr.State()
-model_status = init_model()
-with gr.Blocks(css="""
-.importantButton {
+block_css = """.importantButton {
     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
     border: none !important;
     border: none !important;
 }
 }
@@ -76,24 +80,31 @@ with gr.Blocks(css="""
 .importantButton:hover {
 .importantButton:hover {
     background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
     background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
     border: none !important;
     border: none !important;
-}
+}"""
 
 
-""") as demo:
-    gr.Markdown(
-        f"""
+webui_title = """
 # 🎉langchain-ChatGLM WebUI🎉
 # 🎉langchain-ChatGLM WebUI🎉
 
 
 👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
 👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
 
 
-""")
-    with gr.Row():
-        with gr.Column(scale=2):
-            chatbot = gr.Chatbot([[None, """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
+"""
+
+init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
 1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示
 1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示
 2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
 2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
-3. 输入要提交的问题后,点击回车提交 """], [None, str(model_status)]],
+3. 输入要提交的问题后,点击回车提交 """
+
+
+model_status = init_model()
+
+with gr.Blocks(css=block_css) as demo:
+    vs_path, history, file_status, model_status = gr.State(""), gr.State([]), gr.State(""), gr.State(model_status)
+    gr.Markdown(webui_title)
+    with gr.Row():
+        with gr.Column(scale=2):
+            chatbot = gr.Chatbot([[None, init_message], [None, model_status.value]],
                                  elem_id="chat-box",
                                  elem_id="chat-box",
-                                 show_label=False).style(height=600)
+                                 show_label=False).style(height=750)
             query = gr.Textbox(show_label=False,
             query = gr.Textbox(show_label=False,
                                placeholder="请提问",
                                placeholder="请提问",
                                lines=1,
                                lines=1,
@@ -103,7 +114,7 @@ with gr.Blocks(css="""
         with gr.Column(scale=1):
         with gr.Column(scale=1):
             llm_model = gr.Radio(llm_model_dict_list,
             llm_model = gr.Radio(llm_model_dict_list,
                                  label="LLM 模型",
                                  label="LLM 模型",
-                                 value="chatglm-6b",
+                                 value=LLM_MODEL,
                                  interactive=True)
                                  interactive=True)
             llm_history_len = gr.Slider(0,
             llm_history_len = gr.Slider(0,
                                         10,
                                         10,
@@ -113,7 +124,7 @@ with gr.Blocks(css="""
                                         interactive=True)
                                         interactive=True)
             embedding_model = gr.Radio(embedding_model_dict_list,
             embedding_model = gr.Radio(embedding_model_dict_list,
                                        label="Embedding 模型",
                                        label="Embedding 模型",
-                                       value="text2vec",
+                                       value=EMBEDDING_MODEL,
                                        interactive=True)
                                        interactive=True)
             top_k = gr.Slider(1,
             top_k = gr.Slider(1,
                               20,
                               20,
@@ -133,34 +144,27 @@ with gr.Blocks(css="""
                 file = gr.File(label="content file",
                 file = gr.File(label="content file",
                                file_types=['.txt', '.md', '.docx', '.pdf']
                                file_types=['.txt', '.md', '.docx', '.pdf']
                                )  # .style(height=100)
                                )  # .style(height=100)
-            load_button = gr.Button("重新加载文件")
+            load_file_button = gr.Button("重新加载文件")
     load_model_button.click(reinit_model,
     load_model_button.click(reinit_model,
                             show_progress=True,
                             show_progress=True,
-                            api_name="init_cfg",
-                            inputs=[llm_model, embedding_model, llm_history_len, top_k]
-                            ).then(
-        get_model_status, chatbot, chatbot
-    )
+                            inputs=[llm_model, embedding_model, llm_history_len, top_k],
+                            outputs=model_status
+                            ).then(update_status, [chatbot, model_status], chatbot)
     # 将上传的文件保存到content文件夹下,并更新下拉框
     # 将上传的文件保存到content文件夹下,并更新下拉框
     file.upload(upload_file,
     file.upload(upload_file,
                 inputs=file,
                 inputs=file,
                 outputs=selectFile)
                 outputs=selectFile)
-    load_button.click(get_vector_store,
-                      show_progress=True,
-                      api_name="init_knowledge_vector_store",
-                      inputs=selectFile,
-                      outputs=vs_path
-                      )#.then(
-    #     get_file_status,
-    #     chatbot,
-    #     chatbot,
-    #     show_progress=True,
-    # )
-    # query.submit(get_answer,
-    #              [query, vs_path, chatbot],
-    #              [chatbot, history],
-    #              api_name="get_knowledge_based_answer"
-    #              )
+    load_file_button.click(get_vector_store,
+                           show_progress=True,
+                           inputs=selectFile,
+                           outputs=[vs_path, file_status],
+                           ).then(
+        update_status, [chatbot, file_status], chatbot
+    )
+    query.submit(get_answer,
+                 [query, vs_path, chatbot],
+                 [chatbot, history],
+                 )
 
 
 demo.queue(concurrency_count=3).launch(
 demo.queue(concurrency_count=3).launch(
     server_name='0.0.0.0', share=False, inbrowser=False)
     server_name='0.0.0.0', share=False, inbrowser=False)