|
@@ -8,17 +8,19 @@ import uuid
|
|
|
|
|
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|
|
|
|
|
+
|
|
|
def get_vs_list():
|
|
|
- lst_default = ["新建知识库"]
|
|
|
+ lst_default = ["新建知识库"]
|
|
|
if not os.path.exists(VS_ROOT_PATH):
|
|
|
return lst_default
|
|
|
- lst= os.listdir(VS_ROOT_PATH)
|
|
|
- if not lst:
|
|
|
+ lst = os.listdir(VS_ROOT_PATH)
|
|
|
+ if not lst:
|
|
|
return lst_default
|
|
|
lst.sort(reverse=True)
|
|
|
- return lst+ lst_default
|
|
|
+ return lst + lst_default
|
|
|
+
|
|
|
|
|
|
-vs_list =get_vs_list()
|
|
|
+vs_list = get_vs_list()
|
|
|
|
|
|
embedding_model_dict_list = list(embedding_model_dict.keys())
|
|
|
|
|
@@ -29,6 +31,7 @@ local_doc_qa = LocalDocQA()
|
|
|
logger = gr.CSVLogger()
|
|
|
username = uuid.uuid4().hex
|
|
|
|
|
|
+
|
|
|
def get_answer(query, vs_path, history, mode,
|
|
|
streaming: bool = STREAMING):
|
|
|
if mode == "知识库问答" and vs_path:
|
|
@@ -51,8 +54,9 @@ def get_answer(query, vs_path, history, mode,
|
|
|
streaming=streaming):
|
|
|
history[-1][-1] = resp + (
|
|
|
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
|
|
|
- yield history, ""
|
|
|
- logger.flag([query, vs_path, history, mode],username=username)
|
|
|
+ yield history, ""
|
|
|
+ logger.flag([query, vs_path, history, mode], username=username)
|
|
|
+
|
|
|
|
|
|
def init_model():
|
|
|
try:
|
|
@@ -78,8 +82,8 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, us
|
|
|
embedding_model=embedding_model,
|
|
|
llm_history_len=llm_history_len,
|
|
|
use_ptuning_v2=use_ptuning_v2,
|
|
|
- use_lora = use_lora,
|
|
|
- top_k=top_k,)
|
|
|
+ use_lora=use_lora,
|
|
|
+ top_k=top_k, )
|
|
|
model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
|
|
|
print(model_status)
|
|
|
except Exception as e:
|
|
@@ -111,12 +115,14 @@ def get_vector_store(vs_id, files, history):
|
|
|
return vs_path, None, history + [[None, file_status]]
|
|
|
|
|
|
|
|
|
-def change_vs_name_input(vs_id,history):
|
|
|
+def change_vs_name_input(vs_id, history):
|
|
|
if vs_id == "新建知识库":
|
|
|
- return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None,history
|
|
|
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
|
|
|
else:
|
|
|
file_status = f"已加载知识库{vs_id},请开始提问"
|
|
|
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, vs_id),history + [[None, file_status]]
|
|
|
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH,
|
|
|
+ vs_id), history + [
|
|
|
+ [None, file_status]]
|
|
|
|
|
|
|
|
|
def change_mode(mode):
|
|
@@ -136,6 +142,7 @@ def add_vs_name(vs_name, vs_list, chatbot):
|
|
|
chatbot = chatbot + [[None, vs_status]]
|
|
|
return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
|
|
|
|
|
|
+
|
|
|
block_css = """.importantButton {
|
|
|
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
|
|
|
border: none !important;
|
|
@@ -163,10 +170,11 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
|
|
|
"""
|
|
|
|
|
|
model_status = init_model()
|
|
|
-default_path = os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""
|
|
|
+default_path = os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""
|
|
|
|
|
|
with gr.Blocks(css=block_css) as demo:
|
|
|
- vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(model_status), gr.State(vs_list)
|
|
|
+ vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(
|
|
|
+ model_status), gr.State(vs_list)
|
|
|
gr.Markdown(webui_title)
|
|
|
with gr.Tab("对话"):
|
|
|
with gr.Row():
|
|
@@ -175,7 +183,7 @@ with gr.Blocks(css=block_css) as demo:
|
|
|
elem_id="chat-box",
|
|
|
show_label=False).style(height=750)
|
|
|
query = gr.Textbox(show_label=False,
|
|
|
- placeholder="请输入提问内容,按回车进行提交").style(container=False)
|
|
|
+ placeholder="请输入提问内容,按回车进行提交").style(container=False)
|
|
|
with gr.Column(scale=5):
|
|
|
mode = gr.Radio(["LLM 对话", "知识库问答"],
|
|
|
label="请选择使用模式",
|
|
@@ -218,7 +226,7 @@ with gr.Blocks(css=block_css) as demo:
|
|
|
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
|
|
# load_vs.click(fn=)
|
|
|
select_vs.change(fn=change_vs_name_input,
|
|
|
- inputs=[select_vs,chatbot],
|
|
|
+ inputs=[select_vs, chatbot],
|
|
|
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
|
|
# 将上传的文件保存到content文件夹下,并更新下拉框
|
|
|
load_file_button.click(get_vector_store,
|
|
@@ -230,11 +238,11 @@ with gr.Blocks(css=block_css) as demo:
|
|
|
show_progress=True,
|
|
|
inputs=[select_vs, folder_files, chatbot],
|
|
|
outputs=[vs_path, folder_files, chatbot],
|
|
|
- )
|
|
|
- logger.setup([query, vs_path, chatbot, mode], "flagged")
|
|
|
+ )
|
|
|
+ logger.setup([query, vs_path, chatbot, mode], "flagged")
|
|
|
query.submit(get_answer,
|
|
|
- [query, vs_path, chatbot, mode],
|
|
|
- [chatbot, query])
|
|
|
+ [query, vs_path, chatbot, mode],
|
|
|
+ [chatbot, query])
|
|
|
with gr.Tab("模型配置"):
|
|
|
llm_model = gr.Radio(llm_model_dict_list,
|
|
|
label="LLM 模型",
|
|
@@ -250,8 +258,8 @@ with gr.Blocks(css=block_css) as demo:
|
|
|
label="使用p-tuning-v2微调过的模型",
|
|
|
interactive=True)
|
|
|
use_lora = gr.Checkbox(USE_LORA,
|
|
|
- label="使用lora微调的权重",
|
|
|
- interactive=True)
|
|
|
+ label="使用lora微调的权重",
|
|
|
+ interactive=True)
|
|
|
embedding_model = gr.Radio(embedding_model_dict_list,
|
|
|
label="Embedding 模型",
|
|
|
value=EMBEDDING_MODEL,
|
|
@@ -265,7 +273,8 @@ with gr.Blocks(css=block_css) as demo:
|
|
|
load_model_button = gr.Button("重新加载模型")
|
|
|
load_model_button.click(reinit_model,
|
|
|
show_progress=True,
|
|
|
- inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
|
|
|
+ inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k,
|
|
|
+ chatbot],
|
|
|
outputs=chatbot
|
|
|
)
|
|
|
|