Browse Source

Merge remote-tracking branch 'upstream/dev' into dev

fengyu 2 years ago
parent
commit
53d3037cf2

+ 4 - 0
.gitignore

@@ -0,0 +1,4 @@
+output/*
+__pycache__/*
+log/*
+vector_store/*

+ 58 - 8
README.md

@@ -16,6 +16,8 @@
 
 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
 
+[TOC]
+
 ## 更新信息
 
 **[2023/04/07]** 
@@ -54,29 +56,68 @@
 
     本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。
 ### 软件需求
-本项目已在 python 3.8 环境下完成测试。
-### 1. 安装 python 依赖包
+本项目已在 python 3.8,cuda11.7 环境下完成测试。
+
+
+
+### 1. 安装环境
+
+- 环境检查
+
+```
+# 首先,确信你的机器安装了 Python 3.8 及以上版本
+$ python --version
+Python 3.8.13
+
+# 如果低于这个版本,可使用conda安装环境
+$ conda create -p /your_path/env_name python=3.8
+
+# 激活环境
+$ source activate /your_path/env_name
+
+# 关闭环境
+$ source deactivate /your_path/env_name
+
+# 删除环境
+$ conda env remove -p  /your_path/env_name
+```
+
+- 项目依赖
+
 ```commandline
-pip install -r requirements.txt
+
+# 拉取仓库
+$ git clone https://github.com/imClumsyPanda/langchain-ChatGLM.git
+
+# 安装依赖
+$ pip install -r requirements.txt
+
 ```
 注:使用 langchain.document_loaders.UnstructuredFileLoader 进行非结构化文件接入时,可能需要依据文档进行其他依赖包的安装,请参考 [langchain 文档](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)
 
+
+
 ### 2. 执行脚本体验 Web UI 或命令行交互
 执行 [webui.py](webui.py) 脚本体验 **Web 交互** <img src="https://img.shields.io/badge/Version-0.1-brightgreen">
 ```commandline
 python webui.py
 ```
+注:执行前检查`$HOME/.cache/huggingface/`文件夹剩余空间,至少15G
+
+
+
 执行后效果如下图所示:
 ![webui](img/ui1.png)
 Web UI 中提供的 API 接口如下图所示:
 ![webui](img/ui2.png)
 Web UI 可以实现如下功能:
+
 1. 自动读取`knowledge_based_chatglm.py`中`LLM`及`embedding`模型枚举,选择后点击`setting`进行模型加载,可随时切换模型进行测试
 2. 可手动调节保留对话历史长度,可根据显存大小自行调节
 3. 添加上传文件功能,通过下拉框选择已上传的文件,点击`loading`加载文件,过程中可随时更换加载的文件
 4. 底部添加`use via API`可对接到自己系统
 
-或执行 [knowledge_based_chatglm.py](knowledge_based_chatglm.py) 脚本体验**命令行交互**
+或执行 [knowledge_based_chatglm.py](cli_demo.py) 脚本体验**命令行交互**
 ```commandline
 python knowledge_based_chatglm.py
 ```
@@ -114,11 +155,20 @@ A5: 可以尝试使用 chatglm-6b-int4 模型在 colab 中运行,需要注意
 
 
 
-Q6: 本项目用到的模型权重文件百度网盘地址:
+Q6: 在Anaconda中使用pip安装包无效问题
+
+此问题是系统环境问题,详细见  [在Anaconda中使用pip安装包无效问题](docs/在Anaconda中使用pip安装包无效问题.md)
+
+
+Q7: 本项目用到的模型权重文件百度网盘地址:
+
+ernie-3.0-base-zh.zip 链接: https://pan.baidu.com/s/1CIvKnD3qzE-orFouA8qvNQ?pwd=4wih
+ernie-3.0-nano-zh.zip 链接: https://pan.baidu.com/s/1Fh8fgzVdavf5P1omAJJ-Zw?pwd=q6s5
+text2vec-large-chinese.zip 链接: https://pan.baidu.com/s/1sMyPzBIXdEzHygftEoyBuA?pwd=4xs7
+chatglm-6b-int4-qe.zip 链接: https://pan.baidu.com/s/1DDKMOMHtNZccOOBGWIOYww?pwd=22ji
+chatglm-6b-int4.zip 链接: https://pan.baidu.com/s/1pvZ6pMzovjhkA6uPcRLuJA?pwd=3gjd
+chatglm-6b.zip 链接: https://pan.baidu.com/s/1B-MpsVVs1GHhteVBetaquw?pwd=djay
 
-1. ernie-3.0-base-zh.zip  链接: https://pan.baidu.com/s/1CIvKnD3qzE-orFouA8qvNQ?pwd=4wih
-2. ernie-3.0-nano-zh.zip  链接: https://pan.baidu.com/s/1Fh8fgzVdavf5P1omAJJ-Zw?pwd=q6s5 
-3. 
 
 ## DEMO
 

+ 170 - 59
README_en.md

@@ -1,97 +1,208 @@
-# ChatGLM Application Based on Local Knowledge
+# ChatGLM Application with Local Knowledge Implementation
 
 ## Introduction
 
 🌍 [_中文文档_](README.md)
 
-🤖️ A local knowledge based LLM Application with [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [langchain](https://github.com/hwchase17/langchain).
+🤖️ This is a ChatGLM application based on local knowledge, implemented using [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [langchain](https://github.com/hwchase17/langchain).
 
-💡 Inspired by [document.ai](https://github.com/GanymedeNil/document.ai) by [GanymedeNil](https://github.com/GanymedeNil) and [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) by [AlexZhangji](https://github.com/AlexZhangji).
+💡 Inspired by [document.ai](https://github.com/GanymedeNil/document.ai) and [Alex Zhangji](https://github.com/AlexZhangji)'s [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216), this project establishes a local knowledge question-answering application using open-source models.
 
-✅ In this project, [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) is used as Embedding Model,and [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) used as LLM。Based on those models,this project can be deployed **offline** with all **open source** models。
+✅ The embeddings used in this project are [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main), and the LLM is [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B). Relying on these models, this project enables the use of **open-source** models for **offline private deployment**.
 
-## Webui 
-![webui](./img/ui1.png)
-Click on steps 1-3 according to the above figure to complete the model loading, file loading, and viewing of dialogue history
+⛓️ The implementation principle of this project is illustrated in the figure below. The process includes loading files -> reading text -> text segmentation -> text vectorization -> question vectorization -> matching the top k most similar text vectors to the question vector -> adding the matched text to `prompt` along with the question as context -> submitting to `LLM` to generate an answer.
 
-![webui](./img/ui2.png)
-Click on the Use via API at the bottom to view the API interface. Existing applications can be docked and called through post requests
+![Implementation schematic diagram](img/langchain+chatglm.png)
 
-### TODO
--[] Add Model Load progress bar
--[] Add output content and error prompts
--[] International language switching
--[] Reference annotation
--[] Add plugin system (can be used for basic LORA training, etc.)
+🚩 This project does not involve fine-tuning or training; however, fine-tuning or training can be employed to optimize the effectiveness of this project.
 
-## Update
+[TOC]
 
-**[2023/04/11]** 
-1. Add Webui V0.1 version and synchronize the updated content before the current day;
-2. Automatically read knowledge_ based_ Enumerate LLM and embedding models in chatglm.py, select and click 'setting' to load the model. You can switch models for testing at any time
-3. The length of the conversation history can be manually adjusted and can be adjusted according to the size of the video memory
-4. Add the upload file function, select the uploaded file from the dropdown box, click loading to load the file, and the loaded file can be changed at any time during the process
-5. Add use via API at the bottom to connect to your own system
+## Changelog
 
 **[2023/04/07]**
-1. Fix bug which costs twice gpu memory (Thanks to [@suc16](https://github.com/suc16) and [@myml](https://github.com/myml)).
-2. Add gpu memory clear function after each call of ChatGLM.
-3. Add `nghuyong/ernie-3.0-nano-zh` and `nghuyong/ernie-3.0-base-zh` as Embedding model alternatives,costing less gpu than `GanymedeNil/text2vec-large-chinese` (Thanks to [@lastrei](https://github.com/lastrei))
+
+   1. Resolved the issue of doubled video memory usage when loading the ChatGLM model (thanks to [@suc16](https://github.com/suc16) and [@myml](https://github.com/myml));
+   2. Added a mechanism to clear video memory;
+   3. Added `nghuyong/ernie-3.0-nano-zh` and `nghuyong/ernie-3.0-base-zh` as Embedding model options, which consume less video memory resources than `GanymedeNil/text2vec-large-chinese` (thanks to [@lastrei](https://github.com/lastrei)).
 
 **[2023/04/09]**
-1. Using `RetrievalQA` in `langchain` to replace the previously selected `ChatVectorDBChain`, the replacement can effectively solve the problem of program stopping after 2-3 questions due to insufficient gpu memory.
-2. Add `EMBEDDING_MODEL`, `VECTOR_SEARCH_TOP_K`, `LLM_MODEL`, `LLM_HISTORY_LEN`, `REPLY_WITH_SOURCE` parameter value settings in `knowledge_based_chatglm.py`.
-3. Add `chatglm-6b-int4`, `chatglm-6b-int4-qe` with smaller GPU memory requirements as LLM model alternatives.
-4. Correct code errors in `README.md` (Thanks to [@calcitem](https://github.com/calcitem)).
 
-## Usage
+   1. Replaced the previously selected `ChatVectorDBChain` with `RetrievalQA` in `langchain`, effectively reducing the issue of stopping due to insufficient video memory after asking 2-3 times;
+   2. Added `EMBEDDING_MODEL`, `VECTOR_SEARCH_TOP_K`, `LLM_MODEL`, `LLM_HISTORY_LEN`, `REPLY_WITH_SOURCE` parameter value settings in `knowledge_based_chatglm.py`;
+   3. Added `chatglm-6b-int4` and `chatglm-6b-int4-qe`, which require less GPU memory, as LLM model options;
+   4. Corrected code errors in `README.md` (thanks to [@calcitem](https://github.com/calcitem)).
 
-### Hardware Requirements
+**[2023/04/11]**
+
+   1. Added Web UI V0.1 version (thanks to [@liangtongt](https://github.com/liangtongt));
+   2. Added Frequently Asked Questions in `README.md` (thanks to [@calcitem](https://github.com/calcitem) and [@bolongliu](https://github.com/bolongliu));
+   3. Enhanced automatic detection for the availability of `cuda`, `mps`, and `cpu` for LLM and Embedding model running devices;
+   4. Added a check for `filepath` in `knowledge_based_chatglm.py`. In addition to supporting single file import, it now supports a single folder path as input. After input, it will traverse each file in the folder and display a command-line message indicating the success of each file load.
+
+   **[2023/04/12]**
 
-- ChatGLM Hardware Requirements
+   1. Replaced the sample files in the Web UI to avoid issues with unreadable files due to encoding problems in Ubuntu;
+   2. Replaced the prompt template in `knowledge_based_chatglm.py` to prevent confusion in the content returned by ChatGLM, which may arise from the prompt template containing Chinese and English bilingual text.
 
-    | **Quantization Level** | **GPU Memory** |
-    |------------------------|----------------|
-    | FP16(no quantization)  | 13 GB          |
-    | INT8                   | 10 GB          |
-    | INT4                   | 6 GB           |
-- Embedding Hardware Requirements
+## How to Use
 
-   The default Embedding model in this repo is [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main), 3GB GPU Memory required when running on GPU.
+### Hardware Requirements
+
+- ChatGLM-6B Model Hardware Requirements
+  
+     | **Quantization Level** | **Minimum GPU Memory** (inference) | **Minimum GPU Memory** (efficient parameter fine-tuning) |
+     | -------------- | ------------------------- | -------- ------------------------- |
+     | FP16 (no quantization) | 13 GB | 14 GB |
+     | INT8 | 8 GB | 9 GB |
+     | INT4 | 6 GB | 7 GB |
 
+- Embedding Model Hardware Requirements
+
+     The default Embedding model [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) in this project occupies around 3GB of video memory and can also be configured to run on a CPU.
 ### Software Requirements
-This repo has been tested in python 3.8 environment。
 
-### 1. install python packages
+This repository has been tested with Python 3.8 and CUDA 11.7 environments.
+
+### 1. Setting up the environment
+
+* Environment check
+
+```shell
+# First, make sure your machine has Python 3.8 or higher installed
+$ python --version
+Python 3.8.13
+
+# If your version is lower, you can use conda to install the environment
+$ conda create -p /your_path/env_name python=3.8
+
+# Activate the environment
+$ source activate /your_path/env_name
+
+# Deactivate the environment
+$ source deactivate /your_path/env_name
+
+# Remove the environment
+$ conda env remove -p  /your_path/env_name
+```
+
+* Project dependencies
+
+```shell
+
+# Clone the repository
+$ git clone https://github.com/imClumsyPanda/langchain-ChatGLM.git
+
+# Install dependencies
+$ pip install -r requirements.txt
+```
+
+Note: When using langchain.document_loaders.UnstructuredFileLoader for unstructured file integration, you may need to install other dependency packages according to the documentation. Please refer to [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html).
+
+### 2. Run Scripts to Experience Web UI or Command Line Interaction
+
+Execute [webui.py](webui.py) script to experience **Web interaction** <img src="https://img.shields.io/badge/Version-0.1-brightgreen">
 ```commandline
-pip install -r requirements.txt
+python webui.py
 ```
-Attention: With langchain.document_loaders.UnstructuredFileLoader used to connect with local knowledge file, you may need some other dependencies as mentioned in  [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)
+Note: Before executing, check the remaining space in the `$HOME/.cache/huggingface/` folder, at least 15G.
+
+The resulting interface is shown below:
+![webui](img/ui1.png)
+The API interface provided in the Web UI is shown below:
+![webui](img/ui2.png)The Web UI supports the following features:
+
+1. Automatically reads the `LLM` and `embedding` model enumerations in `knowledge_based_chatglm.py`, allowing you to select and load the model by clicking `setting`. Models can be switched at any time for testing.
+2. The length of retained dialogue history can be manually adjusted according to the available video memory.
+3. Adds a file upload function. Select the uploaded file through the drop-down box, click `loading` to load the file, and change the loaded file at any time during the process.
+4. Adds a `use via API` option at the bottom to connect to your own system.
+
+Alternatively, execute the [knowledge_based_chatglm.py](https://chat.openai.com/chat/cli_demo.py) script to experience **command line interaction**:
 
-### 2. Run [knowledge_based_chatglm.py](knowledge_based_chatglm.py) script
 ```commandline
 python knowledge_based_chatglm.py
 ```
 
-### Known issues
-- Currently tested to support txt, docx, md format files, for more file formats please refer to [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html). If the document contains special characters, the file may not be correctly loaded.
-- When running this project with macOS, it may not work properly due to incompatibility with pytorch caused by macOS version 13.3 and above.
-
 ### FAQ
 
-Q: How to solve `Resource punkt not found.`?
+Q1: What file formats does this project support?
 
-A: Unzip `packages/tokenizers` in https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip and put it in the corresponding directory of `Searched in:`.
+A1: Currently, this project has been tested with txt, docx, and md file formats. For more file formats, please refer to the [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html). It is known that if the document contains special characters, there might be issues with loading the file.
 
-Q: How to solve `Resource averaged_perceptron_tagger not found.`?
+Q2: How can I resolve the `detectron2` dependency issue when reading specific file formats?
 
-A: Download https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip, decompress it and put it in the corresponding directory of `Searched in:`.
+A2: As the installation process for this package can be problematic and it is only required for some file formats, it is not included in `requirements.txt`. You can install it with the following command:
 
-## Roadmap
-
-- [x] local knowledge based application with langchain + ChatGLM-6B
-- [x] unstructured files loaded with langchain
-- [ ] more different file format loaded with langchain
-- [ ] implement web ui DEMO with gradio/streamlit 
-- [ ] implement API with fastapi,and web ui DEMO with API
+```commandline
+pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2"
+```
 
+Q3: How can I solve the `Resource punkt not found.` error?
+
+A3: Unzip the `packages/tokenizers` folder from https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip, and place it in the `nltk_data/tokenizers` storage path.
+
+The `nltk_data` storage path can be found using `nltk.data.path`.
+
+Q4: How can I solve the `Resource averaged_perceptron_tagger not found.` error?
+
+A4: Download https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip, extract it, and place it in the `nltk_data/taggers` storage path.
+
+The `nltk_data` storage path can be found using `nltk.data.path`.
+
+Q5: Can this project run in Google Colab?
+
+A5: You can try running the chatglm-6b-int4 model in Google Colab. Please note that if you want to run the Web UI in Colab, you need to set the `share` parameter in `demo.queue(concurrency_count=3).launch(server_name='0.0.0.0', share=False, inbrowser=False)` to `True`.
+
+This issue is related to the system environment. For more details, please refer to [Issues with installing packages using pip in Anaconda](docs/Issue-with-Installing-Packages-Using-pip-in-Anaconda.md).
+
+## DEMO
+
+Using the question "What are the limitations of ChatGLM-6B and how can they be improved?" as an example:
+
+Without using langchain to access local documents:
+
+> ChatGLM-6B is a chatbot based on GPT-3, and its limitations are mainly reflected in the following aspects:
+>
+>   1. Language model accuracy: Although GPT-3 is a large language model with relatively high accuracy, it still has some limitations. For instance, it may not be able to accurately understand certain dialects, accents, or complex sentence structures.
+>   2. Completeness of the knowledge graph: GPT-3 can generate some knowledge graphs, but its completeness and accuracy may not be sufficient. For example, it may struggle to accurately understand specific domain expertise and terminology.
+>   3. Coherence of dialogue generation: While GPT-3 can generate somewhat coherent dialogues, they may not always be natural. For instance, it may generate false, incoherent dialogue or misinterpret the user's intent.
+>
+>   To improve ChatGLM-6B, consider the following aspects:
+>
+>   1. Enhance the accuracy of the language model: The accuracy of the language model can be improved by increasing the training data, using better language model architectures, and optimizing the model training process.
+>   2. Strengthen the integrity of the knowledge graph: The integrity of the knowledge graph can be enhanced by increasing its dimensions, employing better knowledge graph construction methods, and optimizing the knowledge graph's matching algorithm.
+>   3. Boost the coherence of dialogue generation: The coherence of dialogue generation can be improved by augmenting the context information of the dialogue generation model, utilizing better context information extraction methods, and optimizing the dialogue generation model's algorithm.
+>   4. Incorporate more human feedback: Human feedback can help ChatGLM-6B better understand users' intentions and needs, thereby improving the quality and accuracy of dialogue generation. More human feedback can be introduced by involving more human agents and using human feedback collectors.
+
+ChatGLM's answer after using LangChain to access the README.md file of the ChatGLM-6B project:
+>ChatGLM-6B is a deep learning-based natural language processing model that excels at answering questions. However, it also has some limitations, as follows:
+>1. Limited language understanding: ChatGLM-6B has been primarily trained on Chinese natural language, and its understanding of other languages may be limited.
+>2. Insufficient knowledge base: The training dataset of ChatGLM-6B contains only a Chinese corpus, so it may not be able to answer non-Chinese questions or queries in specific domains.
+>3. Limited data volume: ChatGLM-6B's training dataset has only a few million records, which may hinder its ability to answer very specific or complex questions.
+>
+>To improve ChatGLM-6B, consider the following aspects:
+>1. Expand language knowledge: Learn natural language processing techniques in other languages to broaden the model's language understanding capabilities.
+>2. Broaden the knowledge base: Collect more Chinese corpora or use datasets in other languages to expand the model's knowledge base.
+>3. Increase data volume: Use larger datasets to train ChatGLM-6B, which can improve the model's performance.
+>4. Introduce more evaluation metrics: Incorporate additional evaluation metrics to assess the model's performance, which can help identify the shortcomings and limitations of ChatGLM-6B.
+>5. Enhance the model architecture: Improve ChatGLM-6B's model architecture to boost its performance and capabilities. For example, employ larger neural networks or refined convolutional neural network structures.
+
+## Road map
+
+- [x] Implement LangChain + ChatGLM-6B for local knowledge application
+- [x] Unstructured file access based on langchain
+   - [x].md
+   - [x].pdf (need to install `detectron2` as described in FAQ Q2)
+   - [x].docx
+   - [x].txt
+- [ ] Add support for more LLM models
+   - [x] THUDM/chatglm-6b
+   - [x] THUDM/chatglm-6b-int4
+   - [x] THUDM/chatglm-6b-int4-qe
+- [ ] Add Web UI DEMO
+   - [x]  Implement Web UI DEMO using Gradio
+   - [ ] Add model loading progress bar
+   - [ ] Add output and error messages
+   - [ ] Internationalization for language switching
+   - [ ] Citation callout
+- [ ] Use FastAPI to implement API deployment method and develop a Web UI DEMO for API calls

+ 116 - 0
chains/local_doc_qa.py

@@ -0,0 +1,116 @@
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.document_loaders import UnstructuredFileLoader
+from models.chatglm_llm import ChatGLM
+import sentence_transformers
+import os
+from configs.model_config import *
+import datetime
+from typing import List
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 10
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = True
+
+
+class LocalDocQA:
+    llm: object = None
+    embeddings: object = None
+
+    def init_cfg(self,
+                 embedding_model: str = EMBEDDING_MODEL,
+                 embedding_device=EMBEDDING_DEVICE,
+                 llm_history_len: int = LLM_HISTORY_LEN,
+                 llm_model: str = LLM_MODEL,
+                 llm_device=LLM_DEVICE,
+                 top_k=VECTOR_SEARCH_TOP_K,
+                 ):
+        self.llm = ChatGLM()
+        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
+                            llm_device=llm_device)
+        self.llm.history_len = llm_history_len
+
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
+        self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
+                                                                           device=embedding_device)
+        self.top_k = top_k
+
+    def init_knowledge_vector_store(self,
+                                    filepath: str or List[str]):
+        if isinstance(filepath, str):
+            if not os.path.exists(filepath):
+                print("路径不存在")
+                return None
+            elif os.path.isfile(filepath):
+                file = os.path.split(filepath)[-1]
+                try:
+                    loader = UnstructuredFileLoader(filepath, mode="elements")
+                    docs = loader.load()
+                    print(f"{file} 已成功加载")
+                except:
+                    print(f"{file} 未能成功加载")
+                    return None
+            elif os.path.isdir(filepath):
+                docs = []
+                for file in os.listdir(filepath):
+                    fullfilepath = os.path.join(filepath, file)
+                    try:
+                        loader = UnstructuredFileLoader(fullfilepath, mode="elements")
+                        docs += loader.load()
+                        print(f"{file} 已成功加载")
+                    except:
+                        print(f"{file} 未能成功加载")
+        else:
+            docs = []
+            for file in filepath:
+                try:
+                    loader = UnstructuredFileLoader(file, mode="elements")
+                    docs += loader.load()
+                    print(f"{file} 已成功加载")
+                except:
+                    print(f"{file} 未能成功加载")
+
+        vector_store = FAISS.from_documents(docs, self.embeddings)
+        vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+        vector_store.save_local(vs_path)
+        return vs_path
+
+    def get_knowledge_based_answer(self,
+                                   query,
+                                   vs_path,
+                                   chat_history=[], ):
+        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
+    如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
+    
+    已知内容:
+    {context}
+    
+    问题:
+    {question}"""
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question"]
+        )
+        self.llm.history = chat_history
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        knowledge_chain = RetrievalQA.from_llm(
+            llm=self.llm,
+            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+            prompt=prompt
+        )
+        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+            input_variables=["page_content"], template="{page_content}"
+        )
+
+        knowledge_chain.return_source_documents = True
+
+        result = knowledge_chain({"query": query})
+        self.llm.history[-1][0] = query
+        return result, self.llm.history

+ 33 - 0
cli_demo.py

@@ -0,0 +1,33 @@
+from configs.model_config import *
+from chains.local_doc_qa import LocalDocQA
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 10
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = True
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
+                          embedding_model=EMBEDDING_MODEL,
+                          embedding_device=EMBEDDING_DEVICE,
+                          llm_history_len=LLM_HISTORY_LEN,
+                          top_k=VECTOR_SEARCH_TOP_K)
+    vs_path = None
+    while not vs_path:
+        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
+        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+    history = []
+    while True:
+        query = input("Input your question 请输入问题:")
+        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                vs_path=vs_path,
+                                                                chat_history=history)
+        if REPLY_WITH_SOURCE:
+            print(resp)
+        else:
+            print(resp["result"])

+ 29 - 0
configs/model_config.py

@@ -0,0 +1,29 @@
+import torch.cuda
+import torch.backends
+
+
+embedding_model_dict = {
+    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
+    "ernie-base": "nghuyong/ernie-3.0-base-zh",
+    "text2vec": "GanymedeNil/text2vec-large-chinese",
+}
+
+# Embedding model name
+EMBEDDING_MODEL = "text2vec"
+
+# Embedding running device
+EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+# supported LLM models
+llm_model_dict = {
+    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
+    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
+    "chatglm-6b": "THUDM/chatglm-6b",
+}
+
+# LLM model name
+LLM_MODEL = "chatglm-6b"
+
+# LLM running device
+LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+

+ 0 - 0
content/langchain-ChatGLM README.md → content/langchain-ChatGLM_README.md


+ 114 - 0
docs/Issue-with-Installing-Packages-Using-pip-in-Anaconda.md

@@ -0,0 +1,114 @@
+## Issue with Installing Packages Using pip in Anaconda
+
+## Problem
+
+Recently, when running open-source code, I encountered an issue: after creating a virtual environment with conda and switching to the new environment, using pip to install packages would be "ineffective." Here, "ineffective" means that the packages installed with pip are not in this new environment.
+
+------
+
+## Analysis
+
+1. First, create a test environment called test: `conda create -n test`
+2. Activate the test environment: `conda activate test`
+3. Use pip to install numpy: `pip install numpy`. You'll find that numpy already exists in the default environment.
+
+```powershell
+Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
+Requirement already satisfied: numpy in c:\programdata\anaconda3\lib\site-packages (1.20.3)
+```
+
+4. Check the information of pip: `pip show pip`
+
+```powershell
+Name: pip
+Version: 21.2.4
+Summary: The PyPA recommended tool for installing Python packages.
+Home-page: https://pip.pypa.io/
+Author: The pip developers
+Author-email: distutils-sig@python.org
+License: MIT
+Location: c:\programdata\anaconda3\lib\site-packages
+Requires:
+Required-by:
+```
+
+5. We can see that the current pip is in the default conda environment. This explains why the package is not in the new virtual environment when we directly use pip to install packages - because the pip being used belongs to the default environment, the installed package either already exists or is installed directly into the default environment.
+
+------
+
+## Solution
+
+1. We can directly use the conda command to install new packages, but sometimes conda may not have certain packages/libraries, so we still need to use pip to install.
+2. We can first use the conda command to install the pip package for the current virtual environment, and then use pip to install new packages.
+
+```powershell
+# Use conda to install the pip package
+(test) PS C:\Users\Administrator> conda install pip
+Collecting package metadata (current_repodata.json): done
+Solving environment: done
+....
+done
+
+# Display the information of the current pip, and find that pip is in the test environment
+(test) PS C:\Users\Administrator> pip show pip
+Name: pip
+Version: 21.2.4
+Summary: The PyPA recommended tool for installing Python packages.
+Home-page: https://pip.pypa.io/
+Author: The pip developers
+Author-email: distutils-sig@python.org
+License: MIT
+Location: c:\programdata\anaconda3\envs\test\lib\site-packages
+Requires:
+Required-by:
+
+# Now use pip to install the numpy package, and it is installed successfully
+(test) PS C:\Users\Administrator> pip install numpy
+Looking in indexes: 
+https://pypi.tuna.tsinghua.edu.cn/simple
+Collecting numpy
+  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/4b/23/140ec5a509d992fe39db17200e96c00fd29603c1531ce633ef93dbad5e9e/numpy-1.22.2-cp39-cp39-win_amd64.whl (14.7 MB)
+Installing collected packages: numpy
+Successfully installed numpy-1.22.2
+
+# Use pip list to view the currently installed packages, no problem
+(test) PS C:\Users\Administrator> pip list
+Package      Version
+------------ ---------
+certifi      2021.10.8
+numpy        1.22.2
+pip          21.2.4
+setuptools   58.0.4
+wheel        0.37.1
+wincertstore 0.2
+```
+
+## Supplement
+
+1. The reason I didn't notice this problem before might be because the packages installed in the virtual environment were of a specific version, which overwrote the packages in the default environment. The main issue was actually a lack of careful observation:), otherwise, I could have noticed `Successfully uninstalled numpy-xxx` **default version** and `Successfully installed numpy-1.20.3` **specified version**.
+2. During testing, I found that if the Python version is specified when creating a new package, there shouldn't be this issue. I guess this is because pip will be installed in the virtual environment, while in our case, including pip, no packages were installed, so the default environment's pip was used.
+3. There's a question: I should have specified the Python version when creating a new virtual environment before, but I still used the default environment's pip package. However, I just couldn't reproduce the issue successfully on two different machines, which led to the second point mentioned above.
+4. After encountering the problem mentioned in point 3, I solved it by using `python -m pip install package-name`, adding `python -m` before pip. As for why, you can refer to the answer on [StackOverflow](https://stackoverflow.com/questions/41060382/using-pip-to-install-packages-to-anaconda-environment):
+
+>1. If you have a non-conda pip as your default pip but conda python as your default python (as below):
+>
+>```shell
+>>which -a pip
+>/home/<user>/.local/bin/pip   
+>/home/<user>/.conda/envs/newenv/bin/pip
+>/usr/bin/pip
+>
+>>which -a python
+>/home/<user>/.conda/envs/newenv/bin/python
+>/usr/bin/python
+>```
+>
+>2. Then, instead of calling `pip install <package>` directly, you can use the module flag -m in python so that it installs with the anaconda python
+>
+>```shell
+>python -m pip install <package>
+>```
+>
+>3. This will install the package to the anaconda library directory rather than the library directory associated with the (non-anaconda) pip
+>4. The reason for doing this is as follows: the pip command references a specific pip file/shortcut (which -a pip will tell you which one). Similarly, the python command references a specific python file (which -a python will tell you which one). For one reason or another, these two commands can become out of sync, so your "default" pip is in a different folder than your default python and therefore is associated with different versions of python.
+>5. In contrast, the python -m pip construct does not use the shortcut that the pip command points to. Instead, it asks python to find its pip version and use that version to install a package.

+ 125 - 0
docs/在Anaconda中使用pip安装包无效问题.md

@@ -0,0 +1,125 @@
+##  在 Anaconda 中使用 pip 安装包无效问题
+
+##  问题
+
+最近在跑开源代码的时候遇到的问题:使用 conda 创建虚拟环境并切换到新的虚拟环境后,再使用 pip 来安装包会“无效”。这里的“无效”指的是使用 pip 安装的包不在这个新的环境中。
+
+------
+
+## 分析
+
+1、首先创建一个测试环境 test,`conda create -n test`
+
+2、激活该测试环境,`conda activate test`
+
+3、使用 pip 安装 numpy,`pip install numpy`,会发现 numpy 已经存在默认的环境中
+
+```powershell
+Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
+Requirement already satisfied: numpy in c:\programdata\anaconda3\lib\site-packages (1.20.3)
+```
+
+4、这时候看一下 pip 的信息,`pip show pip`
+
+```powershell
+Name: pip
+Version: 21.2.4
+Summary: The PyPA recommended tool for installing Python packages.
+Home-page: https://pip.pypa.io/
+Author: The pip developers
+Author-email: distutils-sig@python.org
+License: MIT
+Location: c:\programdata\anaconda3\lib\site-packages
+Requires:
+Required-by:
+```
+
+5、可以发现当前 pip 是在默认的 conda 环境中。这也就解释了当我们直接使用 pip 安装包时为什么包不在这个新的虚拟环境中,因为使用的 pip 属于默认环境,安装的包要么已经存在,要么直接装到默认环境中去了。
+
+------
+
+## 解决
+
+1、我们可以直接使用 conda 命令安装新的包,但有些时候 conda 可能没有某些包/库,所以还是得用 pip 安装
+
+2、我们可以先使用 conda 命令为当前虚拟环境安装 pip 包,再使用 pip 安装新的包
+
+```powershell
+# 使用 conda 安装 pip 包
+(test) PS C:\Users\Administrator> conda install pip
+Collecting package metadata (current_repodata.json): done
+Solving environment: done
+....
+done
+
+# 显示当前 pip 的信息,发现 pip 在测试环境 test 中
+(test) PS C:\Users\Administrator> pip show pip
+Name: pip
+Version: 21.2.4
+Summary: The PyPA recommended tool for installing Python packages.
+Home-page: https://pip.pypa.io/
+Author: The pip developers
+Author-email: distutils-sig@python.org
+License: MIT
+Location: c:\programdata\anaconda3\envs\test\lib\site-packages
+Requires:
+Required-by:
+
+# 再使用 pip 安装 numpy 包,成功安装
+(test) PS C:\Users\Administrator> pip install numpy
+Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
+Collecting numpy
+  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/4b/23/140ec5a509d992fe39db17200e96c00fd29603c1531ce633ef93dbad5e9e/numpy-1.22.2-cp39-cp39-win_amd64.whl (14.7 MB)
+Installing collected packages: numpy
+Successfully installed numpy-1.22.2
+
+# 使用 pip list 查看当前安装的包,没有问题
+(test) PS C:\Users\Administrator> pip list
+Package      Version
+------------ ---------
+certifi      2021.10.8
+numpy        1.22.2
+pip          21.2.4
+setuptools   58.0.4
+wheel        0.37.1
+wincertstore 0.2
+```
+
+------
+
+## 补充
+
+1、之前没有发现这个问题可能时因为在虚拟环境中安装的包是指定版本的,覆盖了默认环境中的包。其实主要还是观察不仔细:),不然可以发现 `Successfully uninstalled numpy-xxx`【默认版本】 以及 `Successfully installed numpy-1.20.3`【指定版本】
+
+2、测试时发现如果在新建包的时候指定了 python 版本的话应该是没有这个问题的,猜测时因为会在虚拟环境中安装好 pip ,而我们这里包括 pip 在内啥包也没有装,所以使用的是默认环境的 pip
+
+3、有个问题,之前我在创建新的虚拟环境时应该指定了 python 版本,但还是使用的默认环境的 pip 包,但是刚在在两台机器上都没有复现成功,于是有了上面的第 2 点
+
+4、出现了第 3 点的问题后,我当时是使用 `python -m pip install package-name` 解决的,在 pip 前面加上了 python -m。至于为什么,可以参考 [StackOverflow](https://stackoverflow.com/questions/41060382/using-pip-to-install-packages-to-anaconda-environment) 上的回答:
+
+> 1、如果你有一个非 conda 的 pip 作为你的默认 pip,但是 conda 的 python 是你的默认 python(如下):
+>
+> ```shell
+> >which -a pip
+> /home/<user>/.local/bin/pip   
+> /home/<user>/.conda/envs/newenv/bin/pip
+> /usr/bin/pip
+> 
+> >which -a python
+> /home/<user>/.conda/envs/newenv/bin/python
+> /usr/bin/python
+> ```
+>
+> 2、然后,而不是直接调用 `pip install <package>`,你可以在 python 中使用模块标志 -m,以便它使用 anaconda python 进行安装
+>
+> ```shell
+>python -m pip install <package>
+> ```
+>
+> 3、这将把包安装到 anaconda 库目录,而不是与(非anaconda) pip 关联的库目录
+> 
+> 4、这样做的原因如下:命令 pip 引用了一个特定的 pip 文件 / 快捷方式(which -a pip 会告诉你是哪一个)。类似地,命令 python 引用一个特定的 python 文件(which -a python 会告诉你是哪个)。由于这样或那样的原因,这两个命令可能变得不同步,因此你的“默认” pip 与你的默认 python 位于不同的文件夹中,因此与不同版本的 python 相关联。
+>
+> 5、与此相反,python -m pip 构造不使用 pip 命令指向的快捷方式。相反,它要求 python 找到它的pip 版本,并使用该版本安装一个包。
+
+-   

+ 0 - 124
knowledge_based_chatglm.py

@@ -1,124 +0,0 @@
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
-from langchain.document_loaders import UnstructuredFileLoader
-from chatglm_llm import ChatGLM
-import sentence_transformers
-import torch
-import os
-import readline
-
-
-# Global Parameters
-EMBEDDING_MODEL = "text2vec"
-VECTOR_SEARCH_TOP_K = 6
-LLM_MODEL = "chatglm-6b"
-LLM_HISTORY_LEN = 3
-DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-
-# Show reply with source text from input document
-REPLY_WITH_SOURCE = True
-
-embedding_model_dict = {
-    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-    "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
-}
-
-llm_model_dict = {
-    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
-    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
-    "chatglm-6b": "THUDM/chatglm-6b",
-}
-
-
-def init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN, V_SEARCH_TOP_K=6):
-    global chatglm, embeddings, VECTOR_SEARCH_TOP_K
-    VECTOR_SEARCH_TOP_K = V_SEARCH_TOP_K
-
-    chatglm = ChatGLM()
-    chatglm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL])
-    chatglm.history_len = LLM_HISTORY_LEN
-
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],)
-    embeddings.client = sentence_transformers.SentenceTransformer(embeddings.model_name,
-                                                                  device=DEVICE)
-
-
-def init_knowledge_vector_store(filepath:str):
-    if not os.path.exists(filepath):
-        print("路径不存在")
-        return None
-    elif os.path.isfile(filepath):
-        file = os.path.split(filepath)[-1]
-        try:
-            loader = UnstructuredFileLoader(filepath, mode="elements")
-            docs = loader.load()
-            print(f"{file} 已成功加载")
-        except:
-            print(f"{file} 未能成功加载")
-            return None
-    elif os.path.isdir(filepath):
-        docs = []
-        for file in os.listdir(filepath):
-            fullfilepath = os.path.join(filepath, file)
-            try:
-                loader = UnstructuredFileLoader(fullfilepath, mode="elements")
-                docs += loader.load()
-                print(f"{file} 已成功加载")
-            except:
-                print(f"{file} 未能成功加载")
-
-    vector_store = FAISS.from_documents(docs, embeddings)
-    return vector_store
-
-
-def get_knowledge_based_answer(query, vector_store, chat_history=[]):
-    global chatglm, embeddings
-
-    prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-已知内容:
-{context}
-
-问题:
-{question}"""
-    prompt = PromptTemplate(
-        template=prompt_template,
-        input_variables=["context", "question"]
-    )
-    chatglm.history = chat_history
-    knowledge_chain = RetrievalQA.from_llm(
-        llm=chatglm,
-        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
-        prompt=prompt
-    )
-    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
-
-    knowledge_chain.return_source_documents = True
-
-    result = knowledge_chain({"query": query})
-    chatglm.history[-1][0] = query
-    return result, chatglm.history
-
-
-if __name__ == "__main__":
-    init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN)
-    vector_store = None
-    while not vector_store:
-        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
-        vector_store = init_knowledge_vector_store(filepath)
-    history = []
-    while True:
-        query = input("Input your question 请输入问题:")
-        resp, history = get_knowledge_based_answer(query=query,
-                                                   vector_store=vector_store,
-                                                   chat_history=history)
-        if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])

+ 8 - 12
chatglm_llm.py → models/chatglm_llm.py

@@ -3,8 +3,9 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
 import torch
+from configs.model_config import LLM_DEVICE
 
-DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+DEVICE = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
@@ -38,7 +39,7 @@ class ChatGLM(LLM):
         response, _ = self.model.chat(
             self.tokenizer,
             prompt,
-            history=self.history[-self.history_len:],
+            history=self.history[-self.history_len:] if self.history_len>0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
         )
@@ -48,12 +49,14 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
 
-    def load_model(self, model_name_or_path: str = "THUDM/chatglm-6b"):
+    def load_model(self,
+                   model_name_or_path: str = "THUDM/chatglm-6b",
+                   llm_device=LLM_DEVICE):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
@@ -61,19 +64,12 @@ class ChatGLM(LLM):
                 .half()
                 .cuda()
             )
-        elif torch.backends.mps.is_available():
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=True)
-                .float()
-                .to('mps')
-            )
         else:
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     trust_remote_code=True)
                 .float()
+                .to(llm_device)
             )
         self.model = self.model.eval()

+ 3 - 2
requirements.txt

@@ -1,4 +1,4 @@
-langchain>=0.0.120
+langchain>=0.0.124
 transformers==4.27.1
 unstructured[local-inference]
 layoutparser[layoutmodels,tesseract]
@@ -8,4 +8,5 @@ beautifulsoup4
 icetk
 cpm_kernels
 faiss-cpu
-gradio>=3.25.0
+gradio>=3.25.0
+detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2

+ 97 - 82
webui.py

@@ -1,7 +1,8 @@
 import gradio as gr
 import os
 import shutil
-import knowledge_based_chatglm as kb
+from chains.local_doc_qa import LocalDocQA
+from configs.model_config import *
 
 
 def get_file_list():
@@ -12,9 +13,11 @@ def get_file_list():
 
 file_list = get_file_list()
 
-embedding_model_dict_list = list(kb.embedding_model_dict.keys())
+embedding_model_dict_list = list(embedding_model_dict.keys())
 
-llm_model_dict_list = list(kb.llm_model_dict.keys())
+llm_model_dict_list = list(llm_model_dict.keys())
+
+local_doc_qa = LocalDocQA()
 
 
 def upload_file(file):
@@ -27,9 +30,9 @@ def upload_file(file):
     return gr.Dropdown.update(choices=file_list, value=filename)
 
 
-def get_answer(query, vector_store, history):
-    resp, history = kb.get_knowledge_based_answer(
-        query=query, vector_store=vector_store, chat_history=history)
+def get_answer(query, vs_path, history):
+    resp, history = local_doc_qa.get_knowledge_based_answer(
+        query=query, vs_path=vs_path, chat_history=history)
     return history, history
 
 
@@ -41,6 +44,29 @@ def get_file_status(history):
     return history + [[None, "文档已完成加载,请开始提问"]]
 
 
+def init_model():
+    try:
+        local_doc_qa.init_cfg()
+        return """模型已成功加载,请选择文件后点击"加载文件"按钮"""
+    except:
+        return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
+
+
+def reinit_model(llm_model, embedding_model, llm_history_len, top_k):
+    local_doc_qa.init_cfg(llm_model=llm_model,
+                          embedding_model=embedding_model,
+                          llm_history_len=llm_history_len,
+                          top_k=top_k),
+
+
+def get_vector_store(filepath):
+    local_doc_qa.init_knowledge_vector_store("content/"+filepath)
+
+
+model_status = gr.State()
+history = gr.State([])
+vs_path = gr.State()
+model_status = init_model()
 with gr.Blocks(css="""
 .importantButton {
     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
@@ -63,89 +89,78 @@ with gr.Blocks(css="""
     with gr.Row():
         with gr.Column(scale=2):
             chatbot = gr.Chatbot([[None, """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
-1. 选择语言模型、Embedding 模型及相关参数后点击"step.1: setting",并等待加载完成提示
-2. 上传或选择已有文件作为本地知识文档输入后点击"step.2 loading",并等待加载完成提示
-3. 输入要提交的问题后点击"step.3 asking" """]],
+1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示
+2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
+3. 输入要提交的问题后,点击回车提交 """], [None, str(model_status)]],
                                  elem_id="chat-box",
                                  show_label=False).style(height=600)
-        with gr.Column(scale=1):
-            with gr.Column():
-                llm_model = gr.Radio(llm_model_dict_list,
-                                     label="llm model",
-                                     value="chatglm-6b",
-                                     interactive=True)
-                LLM_HISTORY_LEN = gr.Slider(1,
-                                            10,
-                                            value=3,
-                                            step=1,
-                                            label="LLM history len",
-                                            interactive=True)
-                embedding_model = gr.Radio(embedding_model_dict_list,
-                                           label="embedding model",
-                                           value="text2vec",
-                                           interactive=True)
-                VECTOR_SEARCH_TOP_K = gr.Slider(1,
-                                                20,
-                                                value=6,
-                                                step=1,
-                                                label="vector search top k",
-                                                interactive=True)
-                load_model_button = gr.Button("step.1:setting")
-                load_model_button.click(lambda *args:
-                                        kb.init_cfg(args[0], args[1], args[2], args[3]),
-                                        show_progress=True,
-                                        api_name="init_cfg",
-                                        inputs=[llm_model, embedding_model, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN]
-                                        ).then(
-                    get_model_status, chatbot, chatbot
-                )
-
-            with gr.Column():
-                with gr.Tab("select"):
-                    selectFile = gr.Dropdown(file_list,
-                                             label="content file",
-                                             interactive=True,
-                                             value=file_list[0] if len(file_list) > 0 else None)
-                with gr.Tab("upload"):
-                    file = gr.File(label="content file",
-                                   file_types=['.txt', '.md', '.docx']
-                                   ).style(height=100)
-                    # 将上传的文件保存到content文件夹下,并更新下拉框
-                    file.upload(upload_file,
-                                inputs=file,
-                                outputs=selectFile)
-                history = gr.State([])
-                vector_store = gr.State()
-                load_button = gr.Button("step.2:loading")
-                load_button.click(lambda fileName:
-                                  kb.init_knowledge_vector_store(
-                                      "content/" + fileName),
-                                  show_progress=True,
-                                  api_name="init_knowledge_vector_store",
-                                  inputs=selectFile,
-                                  outputs=vector_store
-                                  ).then(
-                    get_file_status,
-                    chatbot,
-                    chatbot,
-                    show_progress=True,
-                )
-
-    with gr.Row():
-        with gr.Column(scale=2):
             query = gr.Textbox(show_label=False,
-                               placeholder="Prompts",
+                               placeholder="请提问",
                                lines=1,
                                value="用200字总结一下"
                                ).style(container=False)
+
         with gr.Column(scale=1):
-            generate_button = gr.Button("step.3:asking",
-                                        elem_classes="importantButton")
-            generate_button.click(get_answer,
-                                  [query, vector_store, chatbot],
-                                  [chatbot, history],
-                                  api_name="get_knowledge_based_answer"
-                                  )
+            llm_model = gr.Radio(llm_model_dict_list,
+                                 label="LLM 模型",
+                                 value="chatglm-6b",
+                                 interactive=True)
+            llm_history_len = gr.Slider(0,
+                                        10,
+                                        value=3,
+                                        step=1,
+                                        label="LLM history len",
+                                        interactive=True)
+            embedding_model = gr.Radio(embedding_model_dict_list,
+                                       label="Embedding 模型",
+                                       value="text2vec",
+                                       interactive=True)
+            top_k = gr.Slider(1,
+                              20,
+                              value=6,
+                              step=1,
+                              label="向量匹配 top k",
+                              interactive=True)
+            load_model_button = gr.Button("重新加载模型")
+
+            # with gr.Column():
+            with gr.Tab("select"):
+                selectFile = gr.Dropdown(file_list,
+                                         label="content file",
+                                         interactive=True,
+                                         value=file_list[0] if len(file_list) > 0 else None)
+            with gr.Tab("upload"):
+                file = gr.File(label="content file",
+                               file_types=['.txt', '.md', '.docx', '.pdf']
+                               )  # .style(height=100)
+            load_button = gr.Button("重新加载文件")
+    load_model_button.click(reinit_model,
+                            show_progress=True,
+                            api_name="init_cfg",
+                            inputs=[llm_model, embedding_model, llm_history_len, top_k]
+                            ).then(
+        get_model_status, chatbot, chatbot
+    )
+    # 将上传的文件保存到content文件夹下,并更新下拉框
+    file.upload(upload_file,
+                inputs=file,
+                outputs=selectFile)
+    load_button.click(get_vector_store,
+                      show_progress=True,
+                      api_name="init_knowledge_vector_store",
+                      inputs=selectFile,
+                      outputs=vs_path
+                      )#.then(
+    #     get_file_status,
+    #     chatbot,
+    #     chatbot,
+    #     show_progress=True,
+    # )
+    # query.submit(get_answer,
+    #              [query, vs_path, chatbot],
+    #              [chatbot, history],
+    #              api_name="get_knowledge_based_answer"
+    #              )
 
 demo.queue(concurrency_count=3).launch(
     server_name='0.0.0.0', share=False, inbrowser=False)