Prechádzať zdrojové kódy

Merge branch 'dev' into dev_tg

imClumsyPanda 2 rokov pred
rodič
commit
edab7c3c28
16 zmenil súbory, kde vykonal 388 pridanie a 122 odobranie
  1. 22 0
      CONTRIBUTING.md
  2. 34 0
      Dockerfile
  3. 30 18
      README.md
  4. 12 5
      README_en.md
  5. 4 3
      api.py
  6. 19 6
      chains/local_doc_qa.py
  7. 52 0
      chains/text_load.py
  8. 1 1
      cli_demo.py
  9. 6 0
      configs/model_config.py
  10. 21 4
      docs/FAQ.md
  11. BIN
      img/qr_code.jpg
  12. BIN
      img/qr_code_4.jpg
  13. BIN
      img/ui1.png
  14. BIN
      img/webui_0419.png
  15. 72 11
      models/chatglm_llm.py
  16. 115 74
      webui.py

+ 22 - 0
CONTRIBUTING.md

@@ -0,0 +1,22 @@
+# 贡献指南
+
+欢迎!我们是一个非常友好的社区,非常高兴您想要帮助我们让这个应用程序变得更好。但是,请您遵循一些通用准则以保持组织有序。
+
+1. 确保为您要修复的错误或要添加的功能创建了一个[问题](https://github.com/imClumsyPanda/langchain-ChatGLM/issues),尽可能保持它们小。
+2. 请使用 `git pull --rebase` 来拉取和衍合上游的更新。
+3. 将提交合并为格式良好的提交。在提交说明中单独一行提到要解决的问题,如`Fix #<bug>`(有关更多可以使用的关键字,请参见[将拉取请求链接到问题](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))。
+4. 推送到`dev`。在说明中提到正在解决的问题。
+
+---
+
+# Contribution Guide
+
+Welcome! We're a pretty friendly community, and we're thrilled that you want to help make this app even better. However, we ask that you follow some general guidelines to keep things organized around here.
+
+1. Make sure an [issue](https://github.com/imClumsyPanda/langchain-ChatGLM/issues) is created for the bug you're about to fix, or feature you're about to add. Keep them as small as possible.
+
+2. Please use `git pull --rebase` to fetch and merge updates from the upstream.
+
+3. Rebase commits into well-formatted commits. Mention the issue being resolved in the commit message on a line all by itself like `Fixes #<bug>` (refer to [Linking a pull request to an issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) for more keywords you can use).
+
+4. Push into `dev`.  Mention which bug is being resolved in the description.

+ 34 - 0
Dockerfile

@@ -0,0 +1,34 @@
+FROM python:3.8
+
+MAINTAINER "chatGLM"
+
+COPY agent /chatGLM/agent
+
+COPY chains /chatGLM/chains
+
+COPY configs /chatGLM/configs
+
+COPY content /chatGLM/content
+
+COPY models /chatGLM/models
+
+COPY nltk_data /chatGLM/content
+
+COPY requirements.txt /chatGLM/
+
+COPY cli_demo.py /chatGLM/
+
+COPY webui.py /chatGLM/
+
+WORKDIR /chatGLM
+
+RUN pip install --user torch torchvision tensorboard cython -i https://pypi.tuna.tsinghua.edu.cn/simple
+# RUN pip install --user 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
+
+# RUN pip install --user 'git+https://github.com/facebookresearch/fvcore'
+# install detectron2
+# RUN git clone https://github.com/facebookresearch/detectron2
+
+RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host pypi.tuna.tsinghua.edu.cn
+
+CMD ["python","-u", "webui.py"]

+ 30 - 18
README.md

@@ -4,11 +4,11 @@
 
 🌍 [_READ THIS IN ENGLISH_](README_en.md)
 
-🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。
+🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。增加 [clue-ai/ChatYuan](https://github.com/clue-ai/ChatYuan) 项目的模型 [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) 的支持。
 
 💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全部基于开源模型实现的本地知识问答应用。
 
-✅ 本项目中 Embedding 选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**。
+✅ 本项目中 Embedding 默认选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 默认选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**。
 
 ⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的`top k`个 -> 匹配出的文本作为上下文和问题一起添加到`prompt`中 -> 提交给`LLM`生成回答。
 
@@ -22,9 +22,7 @@
 
 参见 [变更日志](docs/CHANGELOG.md)。
 
-## 使用方式
-
-### 硬件需求
+## 硬件需求
 
 - ChatGLM-6B 模型硬件需求
   
@@ -38,9 +36,19 @@
 
     本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。
 
+## Docker 部署
+
+```commandline
+$ docker build -t chatglm:v1.0 .
+
+$ docker run -d --restart=always --name chatglm -p 7860:7860 -v /www/wwwroot/code/langchain-ChatGLM:/chatGLM  chatglm
+```
+
+## 开发部署
+
 ### 软件需求
 
-本项目已在 Python 3.8,CUDA 11.7 环境下完成测试。
+本项目已在 Python 3.8 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。
 
 ### 从本地加载模型
 
@@ -72,7 +80,7 @@ $ python webui.py
 注:如未将模型下载至本地,请执行前检查`$HOME/.cache/huggingface/`文件夹剩余空间,至少15G。
 
 执行后效果如下图所示:
-![webui](img/ui1.png)
+![webui](img/webui_0419.png)
 Web UI 可以实现如下功能:
 
 1. 运行前自动读取`configs/model_config.py`中`LLM`及`Embedding`模型枚举及默认模型设置运行模型,如需重新加载模型,可在界面重新选择后点击`重新加载模型`进行模型加载;
@@ -115,25 +123,29 @@ Web UI 可以实现如下功能:
 
 ## 路线图
 
-- [x] 实现 langchain + ChatGLM-6B 本地知识应用
-- [x] 基于 langchain 实现非结构化文件接入
-  - [x] .md
-  - [x] .pdf(需要按照常见问题 Q2 中描述进行`detectron2`的安装)
-  - [x] .docx
-  - [x] .txt
+- [x] Langchain 应用
+  - [x] 接入非结构化文档(已支持 md、pdf、docx、txt 文件格式)
   - [ ] 搜索引擎与本地网页
+  - [ ] Agent 实现
 - [ ] 增加更多 LLM 模型支持
   - [x] THUDM/chatglm-6b
   - [x] THUDM/chatglm-6b-int4
   - [x] THUDM/chatglm-6b-int4-qe
-- [ ] 增加 Web UI DEMO
+  - [x] ClueAI/ChatYuan-large-v2
+- [ ] Web UI
   - [x] 利用 gradio 实现 Web UI DEMO
   - [x] 添加输出内容及错误提示
-  - [ ] 引用标注
-- [ ] 利用 fastapi 实现 API 部署方式,并实现调用 API 的 web ui DEMO
+  - [x] 引用标注
+  - [ ] 增加知识库管理
+    - [x] 选择知识库开始问答
+    - [x] 上传文件/文件夹至知识库
+    - [ ] 删除知识库中文件
+  - [ ] 利用 streamlit 实现 Web UI Demo
+- [ ] 增加 API 支持
+  - [x] 利用 fastapi 实现 API 部署方式
+  - [ ] 实现调用 API 的 web ui DEMO
 
 ## 项目交流群
-
-![二维码](img/qr_code.jpg)
+![二维码](img/qr_code_4.jpg)
 
 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。

+ 12 - 5
README_en.md

@@ -60,7 +60,7 @@
 - ChatGLM-6B Model Hardware Requirements
   
      | **Quantization Level** | **Minimum GPU Memory** (inference) | **Minimum GPU Memory** (efficient parameter fine-tuning) |
-     | -------------- | ------------------------- | -------- ------------------------- |
+     | -------------- | ------------------------- | --------------------------------- |
      | FP16 (no quantization) | 13 GB | 14 GB |
      | INT8 | 8 GB | 9 GB |
      | INT4 | 6 GB | 7 GB |
@@ -116,7 +116,7 @@ python webui.py
 Note: Before executing, check the remaining space in the `$HOME/.cache/huggingface/` folder, at least 15G.
 
 The resulting interface is shown below:
-![webui](img/ui1.png)
+![webui](img/webui_0419.png)
 The Web UI supports the following features:
 
 1. Automatically reads the `LLM` and `embedding` model enumerations in `configs/model_config.py`, allowing you to select and reload the model by clicking `重新加载模型`.
@@ -206,7 +206,14 @@ ChatGLM's answer after using LangChain to access the README.md file of the ChatG
    - [x] THUDM/chatglm-6b-int4
    - [x] THUDM/chatglm-6b-int4-qe
 - [ ] Add Web UI DEMO
-   - [x]  Implement Web UI DEMO using Gradio
+   - [x] Implement Web UI DEMO using Gradio
    - [x] Add output and error messages
-   - [ ] Citation callout
-- [ ] Use FastAPI to implement API deployment method and develop a Web UI DEMO for API calls
+   - [x] Citation callout
+   - [ ] Knowledge base management
+     - [x] QA based on selected knowledge base
+     - [x] Add files/folder to knowledge base
+     - [ ] Add files/folder to knowledge base
+   - [ ] Implement Web UI DEMO using Streamlit
+- [ ] Add support for API deployment
+  - [x] Use fastapi to implement API
+  - [ ] Implement Web UI DEMO for API calls

+ 4 - 3
api.py

@@ -54,10 +54,11 @@ async def upload_file(UserFile: UploadFile=File(...)):
         # print(UserFile.filename)
         with open(filepath, 'wb') as f:
             f.write(content)
-        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+        vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
         response = {
-            'msg': 'seccessful',
-            'status': 1
+            'msg': 'seccess' if len(files)>0 else 'fail',
+            'status': 1 if len(files)>0 else 0,
+            'loaded_files': files
         }
         
     except Exception as err:

+ 19 - 6
chains/local_doc_qa.py

@@ -44,10 +44,12 @@ class LocalDocQA:
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
                  top_k=VECTOR_SEARCH_TOP_K,
+                 use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device)
+                            llm_device=llm_device,
+                            use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len
 
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
@@ -56,7 +58,9 @@ class LocalDocQA:
         self.top_k = top_k
 
     def init_knowledge_vector_store(self,
-                                    filepath: str or List[str]):
+                                    filepath: str or List[str],
+                                    vs_path: str or os.PathLike = None):
+        loaded_files = []
         if isinstance(filepath, str):
             if not os.path.exists(filepath):
                 print("路径不存在")
@@ -66,6 +70,7 @@ class LocalDocQA:
                 try:
                     docs = load_file(filepath)
                     print(f"{file} 已成功加载")
+                    loaded_files.append(filepath)
                 except Exception as e:
                     print(e)
                     print(f"{file} 未能成功加载")
@@ -77,6 +82,7 @@ class LocalDocQA:
                     try:
                         docs += load_file(fullfilepath)
                         print(f"{file} 已成功加载")
+                        loaded_files.append(fullfilepath)
                     except Exception as e:
                         print(e)
                         print(f"{file} 未能成功加载")
@@ -86,14 +92,21 @@ class LocalDocQA:
                 try:
                     docs += load_file(file)
                     print(f"{file} 已成功加载")
+                    loaded_files.append(file)
                 except Exception as e:
                     print(e)
                     print(f"{file} 未能成功加载")
 
-        vector_store = FAISS.from_documents(docs, self.embeddings)
-        vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+        if vs_path and os.path.isdir(vs_path):
+            vector_store = FAISS.load_local(vs_path, self.embeddings)
+            vector_store.add_documents(docs)
+        else:
+            if not vs_path:
+                vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+            vector_store = FAISS.from_documents(docs, self.embeddings)
+
         vector_store.save_local(vs_path)
-        return vs_path if len(docs) > 0 else None
+        return vs_path if len(docs) > 0 else None, loaded_files
 
     def get_knowledge_based_answer(self,
                                    query,
@@ -123,7 +136,7 @@ class LocalDocQA:
         )
 
         knowledge_chain.return_source_documents = True
-
+        
         result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
         return result, self.llm.history

+ 52 - 0
chains/text_load.py

@@ -0,0 +1,52 @@
+import os
+import pinecone 
+from tqdm import tqdm
+from langchain.llms import OpenAI
+from langchain.text_splitter import SpacyTextSplitter
+from langchain.document_loaders import TextLoader
+from langchain.document_loaders import DirectoryLoader
+from langchain.indexes import VectorstoreIndexCreator
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.vectorstores import Pinecone
+
+#一些配置文件
+openai_key="你的key" # 注册 openai.com 后获得
+pinecone_key="你的key" # 注册 app.pinecone.io 后获得
+pinecone_index="你的库" #app.pinecone.io 获得
+pinecone_environment="你的Environment"  # 登录pinecone后,在indexes页面 查看Environment
+pinecone_namespace="你的Namespace" #如果不存在自动创建
+
+#科学上网你懂得
+os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
+os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
+
+#初始化pinecone
+pinecone.init(
+    api_key=pinecone_key,
+    environment=pinecone_environment
+)
+index = pinecone.Index(pinecone_index)
+
+#初始化OpenAI的embeddings
+embeddings = OpenAIEmbeddings(openai_api_key=openai_key)
+
+#初始化text_splitter
+text_splitter = SpacyTextSplitter(pipeline='zh_core_web_sm',chunk_size=1000,chunk_overlap=200)
+
+# 读取目录下所有后缀是txt的文件
+loader = DirectoryLoader('../docs', glob="**/*.txt", loader_cls=TextLoader)
+
+#读取文本文件
+documents = loader.load()
+
+# 使用text_splitter对文档进行分割
+split_text = text_splitter.split_documents(documents)
+try:
+	for document in tqdm(split_text):
+		# 获取向量并储存到pinecone
+		Pinecone.from_documents([document], embeddings, index_name=pinecone_index)
+except Exception as e:
+    print(f"Error: {e}")
+    quit()
+
+

+ 1 - 1
cli_demo.py

@@ -24,7 +24,7 @@ if __name__ == "__main__":
     vs_path = None
     while not vs_path:
         filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
-        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+        vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
     history = []
     while True:
         query = input("Input your question 请输入问题:")

+ 6 - 0
configs/model_config.py

@@ -25,6 +25,7 @@ llm_model_dict = {
     "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
     "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
     "chatglm-6b": "THUDM/chatglm-6b",
+    "chatyuan": "ClueAI/ChatYuan-large-v2",
 }
 
 # LLM model name
@@ -35,3 +36,8 @@ USE_PTUNING_V2 = False
 
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+VS_ROOT_PATH = "./vector_store/"
+
+UPLOAD_ROOT_PATH = "./content/"
+

+ 21 - 4
docs/FAQ.md

@@ -20,18 +20,29 @@ $ pip install -e .
 
 Q3: 使用过程中 Python 包`nltk`发生了`Resource punkt not found.`报错,该如何解决?
 
-A3: https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip 中的 `packages/tokenizers` 解压,放到  `nltk_data/tokenizers` 存储路径下。
+A3: 方法一:https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip 中的 `packages/tokenizers` 解压,放到  `nltk_data/tokenizers` 存储路径下。
 
  `nltk_data` 存储路径可以通过 `nltk.data.path` 查询。
+ 
+ 方法二:执行python代码
+``` 
+import nltk
+nltk.download()
+``` 
 
 ---
 
 Q4: 使用过程中 Python 包`nltk`发生了`Resource averaged_perceptron_tagger not found.`报错,该如何解决?
 
-A4: 将 https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip 下载,解压放到 `nltk_data/taggers` 存储路径下。
-
- `nltk_data` 存储路径可以通过 `nltk.data.path` 查询。
+A4: 方法一:将 https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip 下载,解压放到 `nltk_data/taggers` 存储路径下。
 
+ `nltk_data` 存储路径可以通过 `nltk.data.path` 查询。  
+ 
+方法二:执行python代码
+``` 
+import nltk
+nltk.download()
+``` 
 ---
 
 Q5: 本项目可否在 colab 中运行?
@@ -101,4 +112,10 @@ embedding_model_dict = {
                         "text2vec": "/Users/liuqian/Downloads/ChatGLM-6B/text2vec-large-chinese"
 }
 ```
+---
 
+Q10: 执行`python cli_demo.py`过程中,显卡内存爆了,提示"OutOfMemoryError: CUDA out of memory"
+
+A10: 将 `VECTOR_SEARCH_TOP_K` 和 `LLM_HISTORY_LEN` 的值调低,比如 `VECTOR_SEARCH_TOP_K = 5` 和 `LLM_HISTORY_LEN = 2`,这样由 `query` 和 `context` 拼接得到的 `prompt` 会变短,会减少内存的占用。
+
+---

BIN
img/qr_code.jpg


BIN
img/qr_code_4.jpg


BIN
img/ui1.png


BIN
img/webui_0419.png


+ 72 - 11
models/chatglm_llm.py

@@ -1,7 +1,7 @@
 import sys
-
 sys.path.append("..")  # 将父目录放入系统路径中
 
+import json
 from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
@@ -12,6 +12,8 @@ import torch
 from configs import *
 from enum import Enum
 
+from typing import Dict, Tuple, Union, Optional
+
 DEVICE = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
@@ -29,7 +31,38 @@ def torch_gc():
             torch.cuda.ipc_collect()
 
 
-class ChatGLM():
+
+def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
+    # transformer.word_embeddings 占用1层
+    # transformer.final_layernorm 和 lm_head 占用1层
+    # transformer.layers 占用 28 层
+    # 总共30层分配到num_gpus张卡上
+    num_trans_layers = 28
+    per_gpu_layers = 30 / num_gpus
+
+    # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError
+    # windows下 model.device 会被设置成 transformer.word_embeddings.device
+    # linux下 model.device 会被设置成 lm_head.device
+    # 在调用chat或者stream_chat时,input_ids会被放到model.device上
+    # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
+    # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
+    device_map = {'transformer.word_embeddings': 0,
+                  'transformer.final_layernorm': 0, 'lm_head': 0}
+
+    used = 2
+    gpu_target = 0
+    for i in range(num_trans_layers):
+        if used >= per_gpu_layers:
+            gpu_target += 1
+            used = 0
+        assert gpu_target < num_gpus
+        device_map[f'transformer.layers.{i}'] = gpu_target
+        used += 1
+
+    return device_map
+
+
+class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
@@ -66,10 +99,25 @@ class ChatGLM():
             self.history = self.history + [[None, response]]
             return response
 
+    def chat(self,
+              prompt: str) -> str:
+        response, _ = self.model.chat(
+            self.tokenizer,
+            prompt,
+            history=[],#self.history[-self.history_len:] if self.history_len>0 else 
+            max_length=self.max_token,
+            temperature=self.temperature,
+        )
+        torch_gc()
+        self.history = self.history+[[None, response]]
+        return response
+        
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
-                   use_ptuning_v2=False):
+                   use_ptuning_v2=False,
+                   device_map: Optional[Dict[str, int]] = None,
+                   **kwargs):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
@@ -88,14 +136,27 @@ class ChatGLM():
                 print("加载PrefixEncoder config.json失败")
 
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    config=model_config,
-                    trust_remote_code=True)
-                .half()
-                .cuda()
-            )
+            # 根据当前设备GPU数量决定是否进行多卡部署
+            num_gpus = torch.cuda.device_count()
+            if num_gpus < 2 and device_map is None:
+                self.model = (
+                    AutoModel.from_pretrained(
+                        model_name_or_path,
+                        config=model_config,
+                        trust_remote_code=True, 
+                        **kwargs)
+                    .half()
+                    .cuda()
+                )
+            else:
+                from accelerate import dispatch_model
+
+                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).half()
+                # 可传入device_map自定义每张卡的部署情况
+                if device_map is None:
+                    device_map = auto_configure_device_map(num_gpus)
+
+                self.model = dispatch_model(model, device_map=device_map)
         else:
             self.model = (
                 AutoModel.from_pretrained(

+ 115 - 74
webui.py

@@ -14,19 +14,12 @@ VECTOR_SEARCH_TOP_K = 6
 LLM_HISTORY_LEN = 3
 
 
-def get_file_list():
-    if not os.path.exists("content"):
-        return []
-    return [f for f in os.listdir("content")]
-
-
 def get_vs_list():
-    if not os.path.exists("vector_store"):
+    if not os.path.exists(VS_ROOT_PATH):
         return []
-    return [f for f in os.listdir("vector_store")]
+    return ["新建知识库"] + os.listdir(VS_ROOT_PATH)
 
 
-file_list = get_file_list()
 vs_list = get_vs_list()
 
 embedding_model_dict_list = list(embedding_model_dict.keys())
@@ -36,19 +29,8 @@ llm_model_dict_list = list(llm_model_dict.keys())
 local_doc_qa = LocalDocQA()
 
 
-def upload_file(file, chatbot):
-    if not os.path.exists("content"):
-        os.mkdir("content")
-    filename = os.path.basename(file.name)
-    shutil.move(file.name, "content/" + filename)
-    # file_list首位插入新上传的文件
-    file_list.insert(0, filename)
-    status = "已将xx上传至xxx"
-    return chatbot + [None, status]
-
-
-def get_answer(query, vs_path, history):
-    if vs_path:
+def get_answer(query, vs_path, history, mode):
+    if vs_path and mode == "知识库问答":
         resp, history = local_doc_qa.get_knowledge_based_answer(
             query=query, vs_path=vs_path, chat_history=history)
         source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
@@ -59,7 +41,7 @@ def get_answer(query, vs_path, history):
         history[-1][-1] += source
     else:
         resp = local_doc_qa.llm._call(query)
-        history = history + [[None, resp + "\n如需基于知识库进行问答,请先加载知识库后,再进行提问。"]]
+        history = history + [[query, resp + ("\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")]]
     return history, ""
 
 
@@ -73,10 +55,10 @@ def init_model():
     try:
         local_doc_qa.init_cfg()
         local_doc_qa.llm._call("你好")
-        return """模型已成功加载,请选择文件后点击"加载文件"按钮"""
+        return """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
     except Exception as e:
         print(e)
-        return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
+        return """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
 
 
 def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
@@ -86,24 +68,54 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
                               top_k=top_k)
-        model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
+        model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
     except Exception as e:
         print(e)
-        model_status = """模型未成功重新加载,请重新选择后点击"加载模型"按钮"""
+        model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
     return history + [[None, model_status]]
 
 
-def get_vector_store(filepath, history):
+def get_vector_store(vs_id, files, history):
+    vs_path = VS_ROOT_PATH + vs_id
+    filelist = []
+    for file in files:
+        filename = os.path.split(file.name)[-1]
+        shutil.move(file.name, UPLOAD_ROOT_PATH + filename)
+        filelist.append(UPLOAD_ROOT_PATH + filename)
     if local_doc_qa.llm and local_doc_qa.embeddings:
-        vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
-        if vs_path:
-            file_status = "文件已成功加载,请开始提问"
+        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
+        if len(loaded_files):
+            file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
         else:
             file_status = "文件未成功加载,请重新上传文件"
     else:
         file_status = "模型未完成加载,请先在加载模型后再导入文件"
         vs_path = None
-    return vs_path, history + [[None, file_status]]
+    return vs_path, None, history + [[None, file_status]]
+
+
+def change_vs_name_input(vs_id):
+    if vs_id == "新建知识库":
+        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None
+    else:
+        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), VS_ROOT_PATH + vs_id
+
+
+def change_mode(mode):
+    if mode == "知识库问答":
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
+
+def add_vs_name(vs_name, vs_list, chatbot):
+    if vs_name in vs_list:
+        chatbot = chatbot + [[None, "与已有知识库名称冲突,请重新选择其他名称后提交"]]
+        return gr.update(visible=True), vs_list, chatbot
+    else:
+        chatbot = chatbot + [
+            [None, f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """]]
+        return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
 
 
 block_css = """.importantButton {
@@ -123,46 +135,88 @@ webui_title = """
 
 """
 
-init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
-1. 选择语言模型、Embedding 模型及相关参数,如果使用 ptuning-v2 方式微调过模型,将 PrefixEncoder 模型放在 ptuning-v2 文件夹里并勾选相关选项,然后点击"重新加载模型",并等待加载完成提示
-2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
-3. 输入要提交的问题后,点击回车提交 """
+init_message = """欢迎使用 langchain-ChatGLM Web UI!
+
+请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。
+
+知识库问答模式中,选择知识库名称后,即可开始问答,如有需要可以在选择知识库名称后上传文件/文件夹至知识库。
+
+知识库暂不支持文件删除,该功能将在后续版本中推出。
+"""
 
 model_status = init_model()
 
 with gr.Blocks(css=block_css) as demo:
-    vs_path, file_status, model_status = gr.State(""), gr.State(""), gr.State(model_status)
+    vs_path, file_status, model_status, vs_list = gr.State(""), gr.State(""), gr.State(model_status), gr.State(vs_list)
     gr.Markdown(webui_title)
-    with gr.Tab("聊天"):
+    with gr.Tab("对话"):
         with gr.Row():
-            with gr.Column(scale=2):
+            with gr.Column(scale=10):
                 chatbot = gr.Chatbot([[None, init_message], [None, model_status.value]],
                                      elem_id="chat-box",
                                      show_label=False).style(height=750)
                 query = gr.Textbox(show_label=False,
                                    placeholder="请输入提问内容,按回车进行提交",
                                    ).style(container=False)
-
-            with gr.Column(scale=1):
-                # with gr.Column():
-                # with gr.Tab("select"):
-                selectFile = gr.Dropdown(vs_list,
-                                         label="请选择要加载的知识库",
-                                         interactive=True,
-                                         value=vs_list[0] if len(vs_list) > 0 else None)
-                #
-                gr.Markdown("向知识库中添加文件")
-                with gr.Tab("上传文件"):
-                    files = gr.File(label="向知识库中添加文件",
-                                    file_types=['.txt', '.md', '.docx', '.pdf'],
-                                    file_count="multiple"
-                                    )  # .style(height=100)
-                with gr.Tab("上传文件夹"):
-                    files = gr.File(label="向知识库中添加文件",
-                                    file_types=['.txt', '.md', '.docx', '.pdf'],
-                                    file_count="directory"
-                                    )  # .style(height=100)
-                load_file_button = gr.Button("加载知识库")
+            with gr.Column(scale=5):
+                mode = gr.Radio(["LLM 对话", "知识库问答"],
+                                label="请选择使用模式",
+                                value="知识库问答", )
+                vs_setting = gr.Accordion("配置知识库")
+                mode.change(fn=change_mode,
+                            inputs=mode,
+                            outputs=vs_setting)
+                with vs_setting:
+                    select_vs = gr.Dropdown(vs_list.value,
+                                            label="请选择要加载的知识库",
+                                            interactive=True,
+                                            value=vs_list.value[0] if len(vs_list.value) > 0 else None
+                                            )
+                    vs_name = gr.Textbox(label="请输入新建知识库名称",
+                                         lines=1,
+                                         interactive=True)
+                    vs_add = gr.Button(value="添加至知识库选项")
+                    vs_add.click(fn=add_vs_name,
+                                 inputs=[vs_name, vs_list, chatbot],
+                                 outputs=[select_vs, vs_list, chatbot])
+
+                    file2vs = gr.Column(visible=False)
+                    with file2vs:
+                        # load_vs = gr.Button("加载知识库")
+                        gr.Markdown("向知识库中添加文件")
+                        with gr.Tab("上传文件"):
+                            files = gr.File(label="添加文件",
+                                            file_types=['.txt', '.md', '.docx', '.pdf'],
+                                            file_count="multiple",
+                                            show_label=False
+                                            )
+                            load_file_button = gr.Button("上传文件并加载知识库")
+                        with gr.Tab("上传文件夹"):
+                            folder_files = gr.File(label="添加文件",
+                                                   # file_types=['.txt', '.md', '.docx', '.pdf'],
+                                                   file_count="directory",
+                                                   show_label=False
+                                                   )
+                            load_folder_button = gr.Button("上传文件夹并加载知识库")
+                    # load_vs.click(fn=)
+                    select_vs.change(fn=change_vs_name_input,
+                                     inputs=select_vs,
+                                     outputs=[vs_name, vs_add, file2vs, vs_path])
+                    # 将上传的文件保存到content文件夹下,并更新下拉框
+                    load_file_button.click(get_vector_store,
+                                           show_progress=True,
+                                           inputs=[select_vs, files, chatbot],
+                                           outputs=[vs_path, files, chatbot],
+                                           )
+                    load_folder_button.click(get_vector_store,
+                                             show_progress=True,
+                                             inputs=[select_vs, folder_files, chatbot],
+                                             outputs=[vs_path, folder_files, chatbot],
+                                             )
+                    query.submit(get_answer,
+                                 [query, vs_path, chatbot, mode],
+                                 [chatbot, query],
+                                 )
     with gr.Tab("模型配置"):
         llm_model = gr.Radio(llm_model_dict_list,
                              label="LLM 模型",
@@ -172,7 +226,7 @@ with gr.Blocks(css=block_css) as demo:
                                     10,
                                     value=LLM_HISTORY_LEN,
                                     step=1,
-                                    label="LLM history len",
+                                    label="LLM 对话轮数",
                                     interactive=True)
         use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                      label="使用p-tuning-v2微调过的模型",
@@ -193,19 +247,6 @@ with gr.Blocks(css=block_css) as demo:
                             inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                             outputs=chatbot
                             )
-    # 将上传的文件保存到content文件夹下,并更新下拉框
-    files.upload(upload_file,
-                 inputs=[files, chatbot],
-                 outputs=chatbot)
-    load_file_button.click(get_vector_store,
-                           show_progress=True,
-                           inputs=[selectFile, chatbot],
-                           outputs=[vs_path, chatbot],
-                           )
-    query.submit(get_answer,
-                 [query, vs_path, chatbot],
-                 [chatbot, query],
-                 )
 
 demo.queue(concurrency_count=3
            ).launch(server_name='0.0.0.0',