
Merge remote-tracking branch 'origin/dev' into dev_llm

# Conflicts:
#	api.py
#	configs/model_config.py
#	models/chatglm_llm.py
#	requirements.txt
#	webui.py
glide-the, 2 years ago
parent commit 7c1f46a641

+ 35 - 0
.github/ISSUE_TEMPLATE/bug_report.md

@@ -0,0 +1,35 @@
+---
+name: Bug 报告 / Bug Report
+about: 报告项目中的错误或问题 / Report errors or issues in the project
+title: "[BUG] 简洁阐述问题 / Concise description of the issue"
+labels: bug
+assignees: ''
+
+---
+
+**问题描述 / Problem Description**
+用简洁明了的语言描述这个问题 / Describe the problem in a clear and concise manner.
+
+**复现问题的步骤 / Steps to Reproduce**
+1. 执行 '...' / Run '...'
+2. 点击 '...' / Click '...'
+3. 滚动到 '...' / Scroll to '...'
+4. 问题出现 / Problem occurs
+
+**预期的结果 / Expected Result**
+描述应该出现的结果 / Describe the expected result.
+
+**实际结果 / Actual Result**
+描述实际发生的结果 / Describe the actual result.
+
+**环境信息 / Environment Information**
+- langchain-ChatGLM 版本/commit 号:(例如:v1.0.0 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v1.0.0 or commit 123456)
+- 是否使用 Docker 部署(是/否):是 / Is Docker deployment used (yes/no): yes
+- 使用的模型(ChatGLM-6B / ClueAI/ChatYuan-large-v2 等):ChatGLM-6B / Model used (ChatGLM-6B / ClueAI/ChatYuan-large-v2, etc.): ChatGLM-6B
+- 使用的 Embedding 模型(GanymedeNil/text2vec-large-chinese 等):GanymedeNil/text2vec-large-chinese / Embedding model used (GanymedeNil/text2vec-large-chinese, etc.): GanymedeNil/text2vec-large-chinese
+- 操作系统及版本 / Operating system and version:
+- Python 版本 / Python version:
+- 其他相关环境信息 / Other relevant environment information:
+
+**附加信息 / Additional Information**
+添加与问题相关的任何其他信息 / Add any other information related to the issue.

+ 23 - 0
.github/ISSUE_TEMPLATE/feature_request.md

@@ -0,0 +1,23 @@
+---
+name: 功能请求 / Feature Request
+about: 为项目提出新功能或建议 / Propose new features or suggestions for the project
+title: "[FEATURE] 简洁阐述功能 / Concise description of the feature"
+labels: enhancement
+assignees: ''
+
+---
+
+**功能描述 / Feature Description**
+用简洁明了的语言描述所需的功能 / Describe the desired feature in a clear and concise manner.
+
+**解决的问题 / Problem Solved**
+解释此功能如何解决现有问题或改进项目 / Explain how this feature solves existing problems or improves the project.
+
+**实现建议 / Implementation Suggestions**
+如果可能,请提供关于如何实现此功能的建议 / If possible, provide suggestions on how to implement this feature.
+
+**替代方案 / Alternative Solutions**
+描述您考虑过的替代方案 / Describe alternative solutions you have considered.
+
+**其他信息 / Additional Information**
+添加与功能请求相关的任何其他信息 / Add any other information related to the feature request.

+ 36 - 0
Dockerfile

@@ -0,0 +1,36 @@
+FROM python:3.8
+
+MAINTAINER "chatGLM"
+
+COPY agent /chatGLM/agent
+
+COPY chains /chatGLM/chains
+
+COPY configs /chatGLM/configs
+
+COPY content /chatGLM/content
+
+COPY models /chatGLM/models
+
+COPY nltk_data /chatGLM/content
+
+COPY requirements.txt /chatGLM/
+
+COPY cli_demo.py /chatGLM/
+
+COPY textsplitter /chatGLM/
+
+COPY webui.py /chatGLM/
+
+WORKDIR /chatGLM
+
+RUN pip install --user torch torchvision tensorboard cython -i https://pypi.tuna.tsinghua.edu.cn/simple
+# RUN pip install --user 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
+
+# RUN pip install --user 'git+https://github.com/facebookresearch/fvcore'
+# install detectron2
+# RUN git clone https://github.com/facebookresearch/detectron2
+
+RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host pypi.tuna.tsinghua.edu.cn
+
+CMD ["python","-u", "webui.py"]

+ 13 - 0
Dockerfile-cuda

@@ -0,0 +1,13 @@
+FROM  nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL MAINTAINER="chatGLM"
+
+COPY . /chatGLM/
+
+WORKDIR /chatGLM
+
+RUN apt-get update -y && apt-get install python3 python3-pip curl -y  && apt-get clean
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py 
+
+RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn && rm -rf `pip3 cache dir`
+
+CMD ["python3","-u", "webui.py"]

+ 53 - 18
README.md

@@ -4,27 +4,31 @@
 
 🌍 [_READ THIS IN ENGLISH_](README_en.md)
 
-🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。
+🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。增加 [clue-ai/ChatYuan](https://github.com/clue-ai/ChatYuan) 项目的模型 [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) 的支持。
 
 💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全部基于开源模型实现的本地知识问答应用。
 
-✅ 本项目中 Embedding 选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**。
+✅ 本项目中 Embedding 默认选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 默认选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**。
 
 ⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的`top k`个 -> 匹配出的文本作为上下文和问题一起添加到`prompt`中 -> 提交给`LLM`生成回答。
 
 ![实现原理图](img/langchain+chatglm.png)
 
+从文档处理角度来看,实现流程如下:
+
+![实现原理图2](img/langchain+chatglm2.png)
+
 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
 
+🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
+
 📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
 
 ## 变更日志
 
 参见 [变更日志](docs/CHANGELOG.md)。
 
-## 使用方式
-
-### 硬件需求
+## 硬件需求
 
 - ChatGLM-6B 模型硬件需求
   
@@ -38,9 +42,29 @@
 
     本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。
 
+## Docker 部署
+为了能让容器使用主机GPU资源,需要在主机上安装 [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit)。具体安装步骤如下:
+```shell
+sudo apt-get update
+sudo apt-get install -y nvidia-container-toolkit-base
+sudo systemctl daemon-reload 
+sudo systemctl restart docker
+```
+安装完成后,可以使用以下命令编译镜像和启动容器:
+```
+docker build -f Dockerfile-cuda -t chatglm-cuda:latest .
+docker run --gpus all -d --name chatglm -p 7860:7860  chatglm-cuda:latest
+
+#若要使用离线模型,请配置好模型路径,然后此repo挂载到Container
+docker run --gpus all -d --name chatglm -p 7860:7860 -v ~/github/langchain-ChatGLM:/chatGLM  chatglm-cuda:latest
+```
+
+
+## 开发部署
+
 ### 软件需求
 
-本项目已在 Python 3.8,CUDA 11.7 环境下完成测试。
+本项目已在 Python 3.8 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。
 
 ### 从本地加载模型
 
@@ -58,7 +82,7 @@
 
 > 注:鉴于环境部署过程中可能遇到问题,建议首先测试命令行脚本。建议命令行脚本测试可正常运行后再运行 Web UI。
 
-执行 [knowledge_based_chatglm.py](cli_demo.py) 脚本体验**命令行交互**:
+执行 [cli_demo.py](cli_demo.py) 脚本体验**命令行交互**:
 ```shell
 $ python cli_demo.py
 ```
@@ -75,9 +99,11 @@ $ python webui.py
 ![webui](img/webui_0419.png)
 Web UI 可以实现如下功能:
 
-1. 运行前自动读取`configs/model_config.py`中`LLM`及`Embedding`模型枚举及默认模型设置运行模型,如需重新加载模型,可在界面重新选择后点击`重新加载模型`进行模型加载;
-2. 可手动调节保留对话历史长度,可根据显存大小自行调节;
-3. 添加上传文件功能,通过下拉框选择已上传的文件,点击`加载文件`按钮,过程中可随时更换加载的文件。
+1. 运行前自动读取`configs/model_config.py`中`LLM`及`Embedding`模型枚举及默认模型设置运行模型,如需重新加载模型,可在 `模型配置` 标签页重新选择后点击 `重新加载模型` 进行模型加载;
+2. 可手动调节保留对话历史长度、匹配知识库文段数量,可根据显存大小自行调节;
+3. 具备模式选择功能,可选择 `LLM对话` 与 `知识库问答` 模式进行对话,支持流式对话;
+4. 添加 `配置知识库` 功能,支持选择已有知识库或新建知识库,并可向知识库中**新增**上传文件/文件夹,使用文件上传组件选择好文件后点击 `上传文件并加载知识库`,会将所选上传文档数据加载至知识库中,并基于更新后知识库进行问答;
+5. 后续版本中将会增加对知识库的修改或删除,及知识库中已导入文件的查看。
 
 ### 常见问题
 
@@ -115,14 +141,23 @@ Web UI 可以实现如下功能:
 
 ## 路线图
 
-- [x] Langchain 应用
+- [ ] Langchain 应用
   - [x] 接入非结构化文档(已支持 md、pdf、docx、txt 文件格式)
-  - [ ] 搜索引擎与本地网页
+  - [ ] 搜索引擎与本地网页接入
+  - [ ] 结构化数据接入(如 csv、Excel、SQL 等)
+  - [ ] 知识图谱/图数据库接入
   - [ ] Agent 实现
 - [ ] 增加更多 LLM 模型支持
-  - [x] THUDM/chatglm-6b
-  - [x] THUDM/chatglm-6b-int4
-  - [x] THUDM/chatglm-6b-int4-qe
+  - [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
+  - [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8)
+  - [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4)
+  - [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
+  - [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
+- [ ] 增加更多 Embedding 模型支持
+  - [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
+  - [x] [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
+  - [x] [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
+  - [x] [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
 - [ ] Web UI
   - [x] 利用 gradio 实现 Web UI DEMO
   - [x] 添加输出内容及错误提示
@@ -133,10 +168,10 @@ Web UI 可以实现如下功能:
     - [ ] 删除知识库中文件
   - [ ] 利用 streamlit 实现 Web UI Demo
 - [ ] 增加 API 支持
-  - [x] 利用 fastapi 实现 API 部署方式
-  - [ ] 实现调用 API 的 web ui DEMO
+  - [ ] 利用 fastapi 实现 API 部署方式
+  - [ ] 实现调用 API 的 Web UI Demo
 
 ## 项目交流群
-![二维码](img/qr_code_4.jpg)
+![二维码](img/qr_code_10.jpg)
 
 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。

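The README hunk above summarizes the retrieval flow (加载文件 -> 文本分割 -> 文本向量化 -> 匹配 top k -> 拼接 prompt -> 提交给 LLM). Below is a minimal sketch of that flow, assuming the same langchain APIs used in chains/local_doc_qa.py later in this diff; the final ChatGLM call is left as a comment because the wrapper lives in models/chatglm_llm.py, and the sample file path is only illustrative.

```python
# Minimal sketch of the README pipeline; not the project's actual entry point.
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

def build_vector_store(filepath: str) -> FAISS:
    docs = UnstructuredFileLoader(filepath).load()        # 加载文件 / 读取文本
    embeddings = HuggingFaceEmbeddings(
        model_name="GanymedeNil/text2vec-large-chinese")  # 文本向量化
    return FAISS.from_documents(docs, embeddings)

def build_prompt(vector_store: FAISS, query: str, top_k: int = 6) -> str:
    related = vector_store.similarity_search(query, k=top_k)  # 匹配最相似的 top k 段
    context = "\n".join(doc.page_content for doc in related)
    return f"已知信息:\n{context}\n\n问题:{query}"             # 上下文 + 问题拼成 prompt

# prompt = build_prompt(build_vector_store("content/sample.md"), "如何部署本项目?")
# answer = chatglm(prompt)  # 提交给 LLM;项目中由 models/chatglm_llm.py 的 ChatGLM 完成
```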
+ 0 - 114
api.py

@@ -1,114 +0,0 @@
-from configs.model_config import *
-from chains.local_doc_qa import LocalDocQA
-import os
-import nltk
-
-import uvicorn
-from fastapi import FastAPI, File, UploadFile
-from pydantic import BaseModel
-from starlette.responses import RedirectResponse
-from models.loader.args import parser
-import models.shared as shared
-from models.loader import LoaderLLM
-from models.chatglm_llm import ChatGLM
-
-app = FastAPI()
-
-global local_doc_qa, vs_path
-
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 10
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
-
-# Show reply with source text from input document
-REPLY_WITH_SOURCE = False
-
-class Query(BaseModel):
-    query: str
-
-@app.get('/')
-async def document():
-    return RedirectResponse(url="/docs")
-
-@app.on_event("startup")
-async def get_local_doc_qa():
-    global local_doc_qa
-
-    chatGLMLLM = ChatGLM(shared.loaderLLM)
-    chatGLMLLM.history_len = LLM_HISTORY_LEN
-    local_doc_qa = LocalDocQA()
-    local_doc_qa.init_cfg(llm_model=chatGLMLLM,
-                          embedding_model=EMBEDDING_MODEL,
-                          embedding_device=EMBEDDING_DEVICE,
-                          top_k=VECTOR_SEARCH_TOP_K)
-    
-
-@app.post("/file")
-async def upload_file(UserFile: UploadFile=File(...)):
-    global vs_path
-    response = {
-        "msg": None,
-        "status": 0
-    }
-    try:
-        filepath = './content/' + UserFile.filename
-        content = await UserFile.read()
-        # print(UserFile.filename)
-        with open(filepath, 'wb') as f:
-            f.write(content)
-        vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
-        response = {
-            'msg': 'seccess' if len(files)>0 else 'fail',
-            'status': 1 if len(files)>0 else 0,
-            'loaded_files': files
-        }
-        
-    except Exception as err:
-        response["message"] = err
-        
-    return response 
-
-@app.post("/qa")
-async def get_answer(UserQuery: Query):
-    response = {
-        "status": 0,
-        "message": "",
-        "answer": None
-    }
-    global vs_path
-    history = []
-    try:
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=UserQuery.query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
-        if REPLY_WITH_SOURCE:
-            response["answer"] = resp
-        else:
-            response['answer'] = resp["result"]
-        
-        response["message"] = 'successful'
-        response["status"] = 1
-
-    except Exception as err:
-        response["message"] = err
-        
-    return response
-
-
-if __name__ == "__main__":
-    args = None
-    args = parser.parse_args()
-    args_dict = vars(args)
-
-    shared.loaderLLM = LoaderLLM(args_dict)
-    uvicorn.run(
-        app='api:app', 
-        host='0.0.0.0', 
-        port=8100,
-        reload = True,
-        )
-

+ 158 - 46
chains/local_doc_qa.py

@@ -1,24 +1,30 @@
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.document_loaders import UnstructuredFileLoader
 from langchain.llms.base import LLM
 from models.chatglm_llm import ChatGLM
-import sentence_transformers
-import os
 from configs.model_config import *
 import datetime
-from typing import List
 from textsplitter import ChineseTextSplitter
+from typing import List, Tuple
+from langchain.docstore.document import Document
+import numpy as np
+from utils import torch_gc
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
 
 
+DEVICE_ = EMBEDDING_DEVICE
+DEVICE_ID = "0" if torch.cuda.is_available() else None
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
+
 
 def load_file(filepath):
-    if filepath.lower().endswith(".pdf"):
+    if filepath.lower().endswith(".md"):
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        docs = loader.load()
+    elif filepath.lower().endswith(".pdf"):
         loader = UnstructuredFileLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True)
         docs = loader.load_and_split(textsplitter)
@@ -29,9 +35,88 @@ def load_file(filepath):
     return docs
 
 
+def generate_prompt(related_docs: List[str],
+                    query: str,
+                    prompt_template=PROMPT_TEMPLATE) -> str:
+    context = "\n".join([doc.page_content for doc in related_docs])
+    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
+    return prompt
+
+
+def get_docs_with_score(docs_with_score):
+    docs = []
+    for doc, score in docs_with_score:
+        doc.metadata["score"] = score
+        docs.append(doc)
+    return docs
+
+
+def seperate_list(ls: List[int]) -> List[List[int]]:
+    lists = []
+    ls1 = [ls[0]]
+    for i in range(1, len(ls)):
+        if ls[i - 1] + 1 == ls[i]:
+            ls1.append(ls[i])
+        else:
+            lists.append(ls1)
+            ls1 = [ls[i]]
+    lists.append(ls1)
+    return lists
+
+
+def similarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+) -> List[Tuple[Document, float]]:
+    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+    docs = []
+    id_set = set()
+    for j, i in enumerate(indices[0]):
+        if i == -1:
+            # This happens when not enough docs are returned.
+            continue
+        _id = self.index_to_docstore_id[i]
+        doc = self.docstore.search(_id)
+        id_set.add(i)
+        docs_len = len(doc.page_content)
+        for k in range(1, max(i, len(docs) - i)):
+            break_flag = False
+            for l in [i + k, i - k]:
+                if 0 <= l < len(self.index_to_docstore_id):
+                    _id0 = self.index_to_docstore_id[l]
+                    doc0 = self.docstore.search(_id0)
+                    if docs_len + len(doc0.page_content) > self.chunk_size:
+                        break_flag=True
+                        break
+                    elif doc0.metadata["source"] == doc.metadata["source"]:
+                        docs_len += len(doc0.page_content)
+                        id_set.add(l)
+            if break_flag:
+                break
+    id_list = sorted(list(id_set))
+    id_lists = seperate_list(id_list)
+    for id_seq in id_lists:
+        for id in id_seq:
+            if id == id_seq[0]:
+                _id = self.index_to_docstore_id[id]
+                doc = self.docstore.search(_id)
+            else:
+                _id0 = self.index_to_docstore_id[id]
+                doc0 = self.docstore.search(_id0)
+                doc.page_content += doc0.page_content
+        if not isinstance(doc, Document):
+            raise ValueError(f"Could not find document for id {_id}, got {doc}")
+        docs.append((doc, scores[0][j]))
+    torch_gc(DEVICE)
+    return docs
+
+
 class LocalDocQA:
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
+    chunk_size: int = CHUNK_SIZE
 
 
     def __init__(self):
         self.top_k = VECTOR_SEARCH_TOP_K
                  ):
                  ):
         self.llm = llm_model
 
-        self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
-                                                                           device=embedding_device)
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
+                                                model_kwargs={'device': embedding_device})
         self.top_k = top_k
 
     def init_knowledge_vector_store(self,
@@ -88,47 +172,75 @@ class LocalDocQA:
                 except Exception as e:
                     print(e)
                     print(f"{file} 未能成功加载")
+        if len(docs) > 0:
+            if vs_path and os.path.isdir(vs_path):
+                vector_store = FAISS.load_local(vs_path, self.embeddings)
+                vector_store.add_documents(docs)
+                torch_gc(DEVICE)
+            else:
+                if not vs_path:
+                    vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+                vector_store = FAISS.from_documents(docs, self.embeddings)
+                torch_gc(DEVICE)
 
 
-            vector_store = FAISS.load_local(vs_path, self.embeddings)
-            vector_store.add_documents(docs)
+            vector_store.save_local(vs_path)
+            return vs_path, loaded_files
         else:
         else:
-                vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-            vector_store = FAISS.from_documents(docs, self.embeddings)
-
-        vector_store.save_local(vs_path)
-        return vs_path if len(docs) > 0 else None, loaded_files
+            print("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
+            return None, loaded_files
 
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[], ):
-        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-    如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-    
-    已知内容:
-    {context}
-    
-    问题:
-    {question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
+                                   chat_history=[],
+                                   streaming: bool = STREAMING):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
-
-        knowledge_chain.return_source_documents = True
-
-        result = knowledge_chain({"query": query})
-        self.llm.history[-1][0] = query
-        return result, self.llm.history
+        FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
+        vector_store.chunk_size = self.chunk_size
+        related_docs_with_score = vector_store.similarity_search_with_score(query,
+                                                                            k=self.top_k)
+        related_docs = get_docs_with_score(related_docs_with_score)
+        prompt = generate_prompt(related_docs, query)
+
+        # if streaming:
+        #     for result, history in self.llm._stream_call(prompt=prompt,
+        #                                                  history=chat_history):
+        #         history[-1][0] = query
+        #         response = {"query": query,
+        #                     "result": result,
+        #                     "source_documents": related_docs}
+        #         yield response, history
+        # else:
+        for result, history in self.llm._call(prompt=prompt,
+                                              history=chat_history,
+                                              streaming=streaming):
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": result,
+                        "source_documents": related_docs}
+            yield response, history
+
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg()
+    query = "本项目使用的embedding模型是什么,消耗多少显存"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
+    last_print_len = 0
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=True):
+        print(resp["result"][last_print_len:], end="", flush=True)
+        last_print_len = len(resp["result"])
+    source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                   # f"""相关度:{doc.metadata['score']}\n\n"""
+                   for inum, doc in
+                   enumerate(resp["source_documents"])]
+    print("\n\n" + "\n\n".join(source_text))
+    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+    #                                                              vs_path=vs_path,
+    #                                                              chat_history=[],
+    #                                                              streaming=False):
+    #     print(resp["result"])
+    pass

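In the new similarity_search_with_score_by_vector above, each FAISS hit is grown with neighbouring chunks from the same source file until chunk_size is exceeded, and seperate_list then groups the collected indices into consecutive runs so adjacent chunks can be concatenated back into a single Document. A standalone check of that grouping helper, with the logic copied from the hunk above:

```python
from typing import List

def seperate_list(ls: List[int]) -> List[List[int]]:
    # Same logic as in chains/local_doc_qa.py: split a sorted list of ids
    # into runs of consecutive values.
    lists = []
    ls1 = [ls[0]]
    for i in range(1, len(ls)):
        if ls[i - 1] + 1 == ls[i]:
            ls1.append(ls[i])
        else:
            lists.append(ls1)
            ls1 = [ls[i]]
    lists.append(ls1)
    return lists

print(seperate_list([3, 4, 5, 9, 10, 15]))  # -> [[3, 4, 5], [9, 10], [15]]
```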
+ 34 - 0
chains/modules/embeddings.py

@@ -0,0 +1,34 @@
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+
+from typing import Any, List
+
+
+class MyEmbeddings(HuggingFaceEmbeddings):
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Compute doc embeddings using a HuggingFace transformer model.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        texts = list(map(lambda x: x.replace("\n", " "), texts))
+        embeddings = self.client.encode(texts, normalize_embeddings=True)
+        return embeddings.tolist()
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using a HuggingFace transformer model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        text = text.replace("\n", " ")
+        embedding = self.client.encode(text, normalize_embeddings=True)
+        return embedding.tolist()

+ 121 - 0
chains/modules/vectorstores.py

@@ -0,0 +1,121 @@
+from langchain.vectorstores import FAISS
+from typing import Any, Callable, List, Optional, Tuple, Dict
+from langchain.docstore.document import Document
+from langchain.docstore.base import Docstore
+
+from langchain.vectorstores.utils import maximal_marginal_relevance
+from langchain.embeddings.base import Embeddings
+import uuid
+from langchain.docstore.in_memory import InMemoryDocstore
+
+import numpy as np
+
+def dependable_faiss_import() -> Any:
+    """Import faiss if available, otherwise raise error."""
+    try:
+        import faiss
+    except ImportError:
+        raise ValueError(
+            "Could not import faiss python package. "
+            "Please install it with `pip install faiss` "
+            "or `pip install faiss-cpu` (depending on Python version)."
+        )
+    return faiss
+
+class FAISSVS(FAISS):
+    def __init__(self, 
+                 embedding_function: Callable[..., Any], 
+                 index: Any, 
+                 docstore: Docstore, 
+                 index_to_docstore_id: Dict[int, str]):
+        super().__init__(embedding_function, index, docstore, index_to_docstore_id)
+
+    def max_marginal_relevance_search_by_vector(
+        self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+
+        Returns:
+            List of Documents with scores selected by maximal marginal relevance.
+        """
+        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
+        # -1 happens when not enough docs are returned.
+        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
+        mmr_selected = maximal_marginal_relevance(
+            np.array([embedding], dtype=np.float32), embeddings, k=k
+        )
+        selected_indices = [indices[0][i] for i in mmr_selected]
+        selected_scores = [scores[0][i] for i in mmr_selected]
+        docs = []
+        for i, score in zip(selected_indices, selected_scores):
+            if i == -1:
+                # This happens when not enough docs are returned.
+                continue
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
+            if not isinstance(doc, Document):
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
+            docs.append((doc, score))
+        return docs
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+
+        Returns:
+            List of Documents with scores selected by maximal marginal relevance.
+        """
+        embedding = self.embedding_function(query)
+        docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
+        return docs
+    
+    @classmethod
+    def __from(
+        cls,
+        texts: List[str],
+        embeddings: List[List[float]],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> FAISS:
+        faiss = dependable_faiss_import()
+        index = faiss.IndexFlatIP(len(embeddings[0]))
+        index.add(np.array(embeddings, dtype=np.float32))
+
+        # # my code, for speeding up search
+        # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
+        # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
+        # index.train(np.array(embeddings, dtype=np.float32))
+        # index.add(np.array(embeddings, dtype=np.float32))
+
+        documents = []
+        for i, text in enumerate(texts):
+            metadata = metadatas[i] if metadatas else {}
+            documents.append(Document(page_content=text, metadata=metadata))
+        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+        docstore = InMemoryDocstore(
+            {index_to_id[i]: doc for i, doc in enumerate(documents)}
+        )
+        return cls(embedding.embed_query, index, docstore, index_to_id)
+

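MyEmbeddings encodes with normalize_embeddings=True and the __from constructor in FAISSVS builds a faiss.IndexFlatIP index, so the scores returned by the store are inner products of unit-length vectors, i.e. cosine similarities rather than L2 distances. A small numpy-only illustration of that equivalence (not project code):

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=128), rng.normal(size=128)

# Normalize to unit length, as sentence-transformers does when
# normalize_embeddings=True is passed to encode().
a /= np.linalg.norm(a)
b /= np.linalg.norm(b)

cosine = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
inner_product = float(np.dot(a, b))  # what IndexFlatIP scores

assert np.isclose(cosine, inner_product)
print(cosine, inner_product)
```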
+ 15 - 6
cli_demo.py

@@ -39,10 +39,19 @@ if __name__ == "__main__":
     history = []
     while True:
         query = input("Input your question 请输入问题:")
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
+        last_print_len = 0
+        for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                     vs_path=vs_path,
+                                                                     chat_history=history,
+                                                                     streaming=STREAMING):
+            if STREAMING:
+                print(resp["result"][last_print_len:], end="", flush=True)
+                last_print_len = len(resp["result"])
+            else:
+                print(resp["result"])
         if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])
+            source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                           # f"""相关度:{doc.metadata['score']}\n\n"""
+                           for inum, doc in
+                           enumerate(resp["source_documents"])]
+            print("\n\n" + "\n\n".join(source_text))

+ 23 - 2
configs/model_config.py

@@ -2,10 +2,12 @@ import torch.cuda
 import torch.backends
 from models.chatglm_llm import *
 from models.llama_llm import LLamaLLM
+import os
 
 embedding_model_dict = {
     "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
+    "text2vec-base": "shibing624/text2vec-base-chinese",
     "text2vec": "GanymedeNil/text2vec-large-chinese",
 }
 
@@ -33,11 +35,22 @@ llm_model_dict = {
         "path": "llama-7b-hf",
         "provides": LLamaLLM
     },
+    "chatyuan": {
+        "path": "ClueAI/ChatYuan-large-v2",
+        "provides": None
+    },
+    "chatglm-6b-int8":{
+        "path":  "THUDM/chatglm-6b-int8",
+        "provides": ChatGLM
+    },
 }
 }
 
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+STREAMING = True
+
 # Use p-tuning-v2 PrefixEncoder
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
 
 # LLM running device
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
+VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store", "")
+
+UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
+
+# 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
+PROMPT_TEMPLATE = """已知信息:
+{context} 
 
 
+根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
 
 
+CHUNK_SIZE = 500

+ 23 - 6
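The new PROMPT_TEMPLATE is consumed by generate_prompt in chains/local_doc_qa.py through plain str.replace on the {context} and {question} markers rather than str.format, which is why the comment insists both placeholders must be kept. A quick sketch of the substitution; the template is copied from the hunk above, the sample strings are made up:

```python
PROMPT_TEMPLATE = """已知信息:
{context} 

根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""

def generate_prompt(context: str, question: str, template: str = PROMPT_TEMPLATE) -> str:
    # Plain replace, as in chains/local_doc_qa.py; literal braces inside the
    # retrieved text are therefore left untouched.
    return template.replace("{question}", question).replace("{context}", context)

print(generate_prompt("本项目默认的 Embedding 模型约占用显存 3GB。",
                      "Embedding 模型大概需要多少显存?"))
```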
docs/FAQ.md

@@ -20,18 +20,29 @@ $ pip install -e .
 
 Q3: 使用过程中 Python 包`nltk`发生了`Resource punkt not found.`报错,该如何解决?
 
-A3: https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip 中的 `packages/tokenizers` 解压,放到  `nltk_data/tokenizers` 存储路径下。
+A3: 方法一:https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip 中的 `packages/tokenizers` 解压,放到  `nltk_data/tokenizers` 存储路径下。
 
  `nltk_data` 存储路径可以通过 `nltk.data.path` 查询。
+ 
+ 方法二:执行python代码
+``` 
+import nltk
+nltk.download()
+``` 
 
 ---
 
 Q4: 使用过程中 Python 包`nltk`发生了`Resource averaged_perceptron_tagger not found.`报错,该如何解决?
 
-A4: 将 https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip 下载,解压放到 `nltk_data/taggers` 存储路径下。
-
- `nltk_data` 存储路径可以通过 `nltk.data.path` 查询。
+A4: 方法一:将 https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip 下载,解压放到 `nltk_data/taggers` 存储路径下。
 
+ `nltk_data` 存储路径可以通过 `nltk.data.path` 查询。  
+ 
+方法二:执行python代码
+``` 
+import nltk
+nltk.download()
+``` 
 ---
 
 Q5: 本项目可否在 colab 中运行?
@@ -84,7 +95,7 @@ Q9: 下载完模型后,如何修改代码以执行本地模型?
 
 A9: 模型下载完成后,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对`embedding_model_dict`和`llm_model_dict`参数进行修改,如把`llm_model_dict`从
 
-```json
+```python
 embedding_model_dict = {
     "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
@@ -94,11 +105,17 @@ embedding_model_dict = {
 
 修改为
 
-```json
+```python
 embedding_model_dict = {
                         "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
                         "ernie-base": "nghuyong/ernie-3.0-base-zh",
                         "text2vec": "/Users/liuqian/Downloads/ChatGLM-6B/text2vec-large-chinese"
 }
 ```
+---
 
+Q10: 执行`python cli_demo.py`过程中,显卡内存爆了,提示"OutOfMemoryError: CUDA out of memory"
+
+A10: 将 `VECTOR_SEARCH_TOP_K` 和 `LLM_HISTORY_LEN` 的值调低,比如 `VECTOR_SEARCH_TOP_K = 5` 和 `LLM_HISTORY_LEN = 2`,这样由 `query` 和 `context` 拼接得到的 `prompt` 会变短,会减少内存的占用。
+
+---

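Q10's advice maps onto constants that already appear elsewhere in this commit: VECTOR_SEARCH_TOP_K (top of chains/local_doc_qa.py) and LLM_HISTORY_LEN (the demo scripts and webui.py). A hedged example of the lowered values the answer suggests; adjust them wherever your checkout defines them:

```python
# Lower-memory settings from FAQ Q10 (example values from the answer above).
VECTOR_SEARCH_TOP_K = 5   # fewer retrieved chunks -> shorter context in the prompt
LLM_HISTORY_LEN = 2       # fewer history turns passed to ChatGLM
```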
BIN
img/langchain+chatglm2.png


BIN
img/qr_code_10.jpg


BIN
img/qr_code_4.jpg


+ 63 - 15
models/chatglm_llm.py

@@ -3,6 +3,11 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 
+from transformers import AutoTokenizer, AutoModel, AutoConfig
+import torch
+from configs.model_config import *
+from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
 from models.loader import LoaderLLM
 
@@ -11,9 +16,12 @@ class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
-    history = []
     llm: LoaderLLM = None
+    # history = []
+    tokenizer: object = None
+    model: object = None
     history_len: int = 10
+    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self, llm: LoaderLLM = None):
         super().__init__()
@@ -25,17 +33,57 @@ class ChatGLM(LLM):
 
     def _call(self,
               prompt: str,
-              stop: Optional[List[str]] = None) -> str:
-        response, _ = self.llm.model.chat(
-            self.llm.tokenizer,
-            prompt,
-            history=self.history[-self.history_len:] if self.history_len > 0 else [],
-            max_length=self.max_token,
-            temperature=self.temperature,
-        )
-        self.llm.clear_torch_cache()
-        if stop is not None:
-            response = enforce_stop_tokens(response, stop)
-
-        self.history = self.history + [[None, response]]
-        return response
+              history: List[List[str]] = [],
+              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+        if streaming:
+            for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
+                    self.tokenizer,
+                    prompt,
+                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
+                    max_length=self.max_token,
+                    temperature=self.temperature,
+            )):
+                torch_gc(DEVICE)
+                if inum == 0:
+                    history += [[prompt, stream_resp]]
+                else:
+                    history[-1] = [prompt, stream_resp]
+                yield stream_resp, history
+        else:
+            response, _ = self.model.chat(
+                    self.tokenizer,
+                    prompt,
+                    history=history[-self.history_len:] if self.history_len > 0 else [],
+                    max_length=self.max_token,
+                    temperature=self.temperature,
+            )
+            torch_gc(DEVICE)
+            history += [[prompt, response]]
+            yield response, history
+
+    # def chat(self,
+    #          prompt: str) -> str:
+    #     response, _ = self.model.chat(
+    #         self.tokenizer,
+    #         prompt,
+    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
+    #         max_length=self.max_token,
+    #         temperature=self.temperature,
+    #     )
+    #     torch_gc()
+    #     self.history = self.history + [[None, response]]
+    #     return response
+
+
+
+if __name__ == "__main__":
+    llm = ChatGLM()
+    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
+                   llm_device=LLM_DEVICE, )
+    last_print_len=0
+    for resp, history in llm._call("你好", streaming=True):
+        print(resp[last_print_len:], end="", flush=True)
+        last_print_len = len(resp)
+    for resp, history in llm._call("你好", streaming=False):
+        print(resp)
+    pass

+ 5 - 4
requirements.txt

@@ -1,5 +1,5 @@
 langchain>=0.0.124
-transformers==4.28.1
+transformers==4.27.1
 unstructured[local-inference]
 layoutparser[layoutmodels,tesseract]
 nltk
@@ -8,8 +8,9 @@ beautifulsoup4
 icetk
 cpm_kernels
 faiss-cpu
-gradio>=3.25.0
-accelerate==0.18.0
+accelerate
+gradio==3.24.1
 llama-cpp-python==0.1.34; platform_system != "Windows"
 https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
-peft
+peft
+#detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2

+ 11 - 0
utils/__init__.py

@@ -0,0 +1,11 @@
+import torch.cuda
+import torch.mps
+import torch.backends
+
+def torch_gc(DEVICE):
+    if torch.cuda.is_available():
+        with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()

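torch_gc releases cached allocator blocks on the given CUDA device (or the MPS cache on Apple silicon); the rest of this commit calls it after each generation or embedding step. A minimal usage sketch, assuming it runs inside this repository so `utils` resolves to the package added above, with the DEVICE string built the same way as in chains/local_doc_qa.py:

```python
import torch
from utils import torch_gc

# "cuda:0" when a GPU is visible, otherwise the bare device string.
DEVICE_ = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_

# ... run a generation / embedding step here ...
torch_gc(DEVICE)  # free cached GPU memory between requests
```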
+ 48 - 32
webui.py

@@ -1,11 +1,6 @@
 import gradio as gr
-import sys
 import os
 import shutil
-import asyncio
-from argparse import Namespace
-from models.loader.args import parser
-from models.loader import LoaderLLM
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
@@ -24,10 +19,10 @@ LLM_HISTORY_LEN = 3
 def get_vs_list():
     if not os.path.exists(VS_ROOT_PATH):
         return []
-    return ["新建知识库"] + os.listdir(VS_ROOT_PATH)
+    return os.listdir(VS_ROOT_PATH)
 
 
-vs_list = get_vs_list()
+vs_list = ["新建知识库"] + get_vs_list()
 
 embedding_model_dict_list = list(embedding_model_dict.keys())
 
@@ -36,20 +31,29 @@ llm_model_dict_list = list(llm_model_dict.keys())
 local_doc_qa = LocalDocQA()
 
 
-def get_answer(query, vs_path, history, mode):
-    if vs_path and mode == "知识库问答":
-        resp, history = local_doc_qa.get_knowledge_based_answer(
-            query=query, vs_path=vs_path, chat_history=history)
-        source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
-{doc.page_content}
-
-<b>所属文件:</b>{doc.metadata["source"]}
-</details>""" for i, doc in enumerate(resp["source_documents"])])
-        history[-1][-1] += source
+def get_answer(query, vs_path, history, mode,
+               streaming: bool = STREAMING):
+    if mode == "知识库问答" and vs_path:
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=query,
+                vs_path=vs_path,
+                chat_history=history,
+                streaming=streaming):
+            source = "\n\n"
+            source += "".join(
+                [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                 f"""{doc.page_content}\n"""
+                 f"""</details>"""
+                 for i, doc in
+                 enumerate(resp["source_documents"])])
+            history[-1][-1] += source
+            yield history, ""
     else:
-        resp = local_doc_qa.llm._call(query)
-        history = history + [[query, resp + ("\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")]]
-    return history, ""
+        for resp, history in local_doc_qa.llm._call(query, history,
+                                                    streaming=streaming):
+            history[-1][-1] = resp + (
+                "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+            yield history, ""
 
 
 def update_status(history, status):
@@ -62,10 +66,18 @@ def init_model(llm_model: LLM = None):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model)
         local_doc_qa.llm._call("你好")
-        return """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
+        reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
+        print(reply)
+        return reply
     except Exception as e:
         print(e)
-        return """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
+        reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
+        if str(e) == "Unknown platform: darwin":
+            print("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
+                  " https://github.com/imClumsyPanda/langchain-ChatGLM")
+        else:
+            print(reply)
+        return reply
 
 
 def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, top_k, history):
@@ -83,9 +95,11 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
                               embedding_model=embedding_model,
                               top_k=top_k)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
+        print(model_status)
     except Exception as e:
         print(e)
         model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
+        print(model_status)
     return history + [[None, model_status]]
 
 
@@ -105,6 +119,7 @@ def get_vector_store(vs_id, files, history):
     else:
         file_status = "模型未完成加载,请先在加载模型后再导入文件"
         vs_path = None
+    print(file_status)
     return vs_path, None, history + [[None, file_status]]
 
 
@@ -124,11 +139,12 @@ def change_mode(mode):
 
 def add_vs_name(vs_name, vs_list, chatbot):
     if vs_name in vs_list:
-        chatbot = chatbot + [[None, "与已有知识库名称冲突,请重新选择其他名称后提交"]]
+        vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
+        chatbot = chatbot + [[None, vs_status]]
         return gr.update(visible=True), vs_list, chatbot
     else:
-        chatbot = chatbot + [
-            [None, f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """]]
+        vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
+        chatbot = chatbot + [[None, vs_status]]
         return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
 
 
@@ -275,10 +291,10 @@ with gr.Blocks(css=block_css) as demo:
                             outputs=chatbot
                             )
 
-demo.queue(concurrency_count=3)
-demo.launch(server_name='0.0.0.0',
-            server_port=7860,
-            show_api=False,
-            share=False,
-            inbrowser=False)
-demo.close()
+(demo
+ .queue(concurrency_count=3)
+ .launch(server_name='0.0.0.0',
+         server_port=7860,
+         show_api=False,
+         share=False,
+         inbrowser=False))
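
The reworked get_answer is a generator: every partial result is yielded as an updated chatbot history, and because the demo is wrapped in .queue(), gradio streams those intermediate yields to the browser. A stripped-down sketch of the same pattern (plain gradio 3.x, no project code; the echo function below is only illustrative):

```python
import time
import gradio as gr

def echo_stream(message, history):
    # Generator callback: with queue() enabled, gradio re-renders the Chatbot
    # after every yield, which is how webui.py streams ChatGLM output.
    history = (history or []) + [[message, ""]]
    for ch in f"你说的是:{message}":
        history[-1][1] += ch
        time.sleep(0.05)
        yield history, ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    box = gr.Textbox()
    box.submit(echo_stream, inputs=[box, chatbot], outputs=[chatbot, box])

demo.queue(concurrency_count=3).launch()
```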