
Modify project architecture

imClumsyPanda · 2 years ago
commit a6184b01be
8 changed files with 181 additions and 139 deletions
  1. README.md (+3, -1)
  2. README_en.md (+1, -1)
  3. chains/local_doc_qa.py (+104, -0)
  4. cli_demo.py (+33, -0)
  5. configs/model_config.py (+31, -0)
  6. knowledge_based_chatglm.py (+0, -124)
  7. models/chatglm_llm.py (+7, -11)
  8. webui.py (+2, -2)

+ 3 - 1
README.md

@@ -16,6 +16,8 @@
 
 
 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
 
+[TOC]
+
 ## 更新信息
 
 **[2023/04/07]** 
@@ -76,7 +78,7 @@ Web UI 可以实现如下功能:
 3. 添加上传文件功能,通过下拉框选择已上传的文件,点击`loading`加载文件,过程中可随时更换加载的文件
 4. 底部添加`use via API`可对接到自己系统
 
-或执行 [knowledge_based_chatglm.py](knowledge_based_chatglm.py) 脚本体验**命令行交互**
+或执行 [knowledge_based_chatglm.py](cli_demo.py) 脚本体验**命令行交互**
 ```commandline
 python knowledge_based_chatglm.py
 ```

+ 1 - 1
README_en.md

@@ -68,7 +68,7 @@ pip install -r requirements.txt
 ```
 Attention: With langchain.document_loaders.UnstructuredFileLoader used to connect with local knowledge file, you may need some other dependencies as mentioned in  [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)
 
-### 2. Run [knowledge_based_chatglm.py](knowledge_based_chatglm.py) script
+### 2. Run [knowledge_based_chatglm.py](cli_demo.py) script
 ```commandline
 python knowledge_based_chatglm.py
 ```

+ 104 - 0
chains/local_doc_qa.py

@@ -0,0 +1,104 @@
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.document_loaders import UnstructuredFileLoader
+from models.chatglm_llm import ChatGLM
+import sentence_transformers
+import os
+from configs.model_config import *
+import datetime
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 10
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = True
+
+
+class LocalDocQA:
+    llm: object = None
+    embeddings: object = None
+
+    def init_cfg(self,
+                 embedding_model: str = EMBEDDING_MODEL,
+                 embedding_device=EMBEDDING_DEVICE,
+                 llm_history_len: int = LLM_HISTORY_LEN,
+                 llm_model: str = LLM_MODEL,
+                 llm_device=LLM_DEVICE
+                 ):
+        self.llm = ChatGLM()
+        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
+                            llm_device=llm_device)
+        self.llm.history_len = llm_history_len
+
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
+        self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
+                                                                           device=embedding_device)
+
+    def init_knowledge_vector_store(self,
+                                    filepath: str):
+        if not os.path.exists(filepath):
+            print("路径不存在")
+            return None
+        elif os.path.isfile(filepath):
+            file = os.path.split(filepath)[-1]
+            try:
+                loader = UnstructuredFileLoader(filepath, mode="elements")
+                docs = loader.load()
+                print(f"{file} 已成功加载")
+            except:
+                print(f"{file} 未能成功加载")
+                return None
+        elif os.path.isdir(filepath):
+            docs = []
+            for file in os.listdir(filepath):
+                fullfilepath = os.path.join(filepath, file)
+                try:
+                    loader = UnstructuredFileLoader(fullfilepath, mode="elements")
+                    docs += loader.load()
+                    print(f"{file} 已成功加载")
+                except:
+                    print(f"{file} 未能成功加载")
+
+        vector_store = FAISS.from_documents(docs, self.embeddings)
+        vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
+        vector_store.save_local(vs_path)
+        return vs_path
+
+    def get_knowledge_based_answer(self,
+                                   query,
+                                   vs_path,
+                                   chat_history=[],
+                                   top_k=VECTOR_SEARCH_TOP_K):
+        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
+    如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
+    
+    已知内容:
+    {context}
+    
+    问题:
+    {question}"""
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question"]
+        )
+        self.llm.history = chat_history
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        knowledge_chain = RetrievalQA.from_llm(
+            llm=self.llm,
+            retriever=vector_store.as_retriever(search_kwargs={"k": top_k}),
+            prompt=prompt
+        )
+        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+            input_variables=["page_content"], template="{page_content}"
+        )
+
+        knowledge_chain.return_source_documents = True
+
+        result = knowledge_chain({"query": query})
+        self.llm.history[-1][0] = query
+        return result, self.llm.history
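
For orientation, here is a minimal sketch of how the new `LocalDocQA` class is meant to be driven, matching the flow of `cli_demo.py` added below. The knowledge file path is a placeholder, and the default arguments assume the checkpoints configured in `configs/model_config.py` are reachable on your machine.

```python
# Minimal usage sketch; "content/sample.md" is a placeholder path and the
# models resolve through configs/model_config.py.
from chains.local_doc_qa import LocalDocQA

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg()  # loads ChatGLM and the embedding model
vs_path = local_doc_qa.init_knowledge_vector_store("content/sample.md")  # builds and saves a FAISS index

history = []
resp, history = local_doc_qa.get_knowledge_based_answer(
    query="这篇文档讲了什么?",
    vs_path=vs_path,
    chat_history=history,
)
print(resp["result"])            # generated answer
print(resp["source_documents"])  # retrieved chunks, since return_source_documents=True
```

Note that the FAISS index is persisted under `./vector_store/` and reloaded from `vs_path` on every query, so the same store can be reused across sessions.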

+ 33 - 0
cli_demo.py

@@ -0,0 +1,33 @@
+from configs.model_config import *
+import datetime
+from chains.local_doc_qa import LocalDocQA
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 10
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = True
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
+                          embedding_model=EMBEDDING_MODEL,
+                          embedding_device=EMBEDDING_DEVICE,
+                          llm_history_len=LLM_HISTORY_LEN)
+    vs_path = None
+    while not vs_path:
+        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
+        vs_path = local_doc_qa.init_knowledge_vector_store(filepath)
+    history = []
+    while True:
+        query = input("Input your question 请输入问题:")
+        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                vs_path=vs_path,
+                                                                chat_history=history)
+        if REPLY_WITH_SOURCE:
+            print(resp)
+        else:
+            print(resp["result"])
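
When `REPLY_WITH_SOURCE` is on, the script prints the raw result dict as-is. The dict shape comes from `RetrievalQA` with `return_source_documents=True`; a small illustrative helper (not part of this commit) shows how it could be unpacked:

```python
# Illustrative only: print the answer plus a preview of each retrieved chunk.
def print_with_sources(resp: dict) -> None:
    print(resp["result"])  # answer text
    for i, doc in enumerate(resp.get("source_documents", []), start=1):
        # each doc is a langchain Document built from the loaded knowledge file
        print(f"[{i}] {doc.page_content[:80]}")
```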

+ 31 - 0
configs/model_config.py

@@ -0,0 +1,31 @@
+import torch.cuda
+import torch.backends
+
+
+embedding_model_dict = {
+    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
+    "ernie-base": "nghuyong/ernie-3.0-base-zh",
+    "text2vec": "GanymedeNil/text2vec-large-chinese",
+    "local": "/Users/liuqian/Downloads/ChatGLM-6B/text2vec-large-chinese"
+}
+
+# Embedding model name
+EMBEDDING_MODEL = "local"#"text2vec"
+
+# Embedding running device
+EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+# supported LLM models
+llm_model_dict = {
+    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
+    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
+    "chatglm-6b": "THUDM/chatglm-6b",
+    "local": "/Users/liuqian/Downloads/ChatGLM-6B/chatglm-6b"
+}
+
+# LLM model name
+LLM_MODEL = "local"#"chatglm-6b"
+
+# LLM running device
+LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
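
The `"local"` entries above point at paths on the author's machine. On other setups the apparent intent is to switch back to the Hugging Face Hub entries, for example:

```python
# Hypothetical override for a machine without the local checkpoints:
EMBEDDING_MODEL = "text2vec"   # resolves to "GanymedeNil/text2vec-large-chinese"
LLM_MODEL = "chatglm-6b"       # resolves to "THUDM/chatglm-6b"
```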

+ 0 - 124
knowledge_based_chatglm.py

@@ -1,124 +0,0 @@
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
-from langchain.document_loaders import UnstructuredFileLoader
-from chatglm_llm import ChatGLM
-import sentence_transformers
-import torch
-import os
-import readline
-
-
-# Global Parameters
-EMBEDDING_MODEL = "text2vec"
-VECTOR_SEARCH_TOP_K = 6
-LLM_MODEL = "chatglm-6b"
-LLM_HISTORY_LEN = 3
-DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-
-# Show reply with source text from input document
-REPLY_WITH_SOURCE = True
-
-embedding_model_dict = {
-    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-    "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
-}
-
-llm_model_dict = {
-    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
-    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
-    "chatglm-6b": "THUDM/chatglm-6b",
-}
-
-
-def init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN, V_SEARCH_TOP_K=6):
-    global chatglm, embeddings, VECTOR_SEARCH_TOP_K
-    VECTOR_SEARCH_TOP_K = V_SEARCH_TOP_K
-
-    chatglm = ChatGLM()
-    chatglm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL])
-    chatglm.history_len = LLM_HISTORY_LEN
-
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],)
-    embeddings.client = sentence_transformers.SentenceTransformer(embeddings.model_name,
-                                                                  device=DEVICE)
-
-
-def init_knowledge_vector_store(filepath:str):
-    if not os.path.exists(filepath):
-        print("路径不存在")
-        return None
-    elif os.path.isfile(filepath):
-        file = os.path.split(filepath)[-1]
-        try:
-            loader = UnstructuredFileLoader(filepath, mode="elements")
-            docs = loader.load()
-            print(f"{file} 已成功加载")
-        except:
-            print(f"{file} 未能成功加载")
-            return None
-    elif os.path.isdir(filepath):
-        docs = []
-        for file in os.listdir(filepath):
-            fullfilepath = os.path.join(filepath, file)
-            try:
-                loader = UnstructuredFileLoader(fullfilepath, mode="elements")
-                docs += loader.load()
-                print(f"{file} 已成功加载")
-            except:
-                print(f"{file} 未能成功加载")
-
-    vector_store = FAISS.from_documents(docs, embeddings)
-    return vector_store
-
-
-def get_knowledge_based_answer(query, vector_store, chat_history=[]):
-    global chatglm, embeddings
-
-    prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-已知内容:
-{context}
-
-问题:
-{question}"""
-    prompt = PromptTemplate(
-        template=prompt_template,
-        input_variables=["context", "question"]
-    )
-    chatglm.history = chat_history
-    knowledge_chain = RetrievalQA.from_llm(
-        llm=chatglm,
-        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
-        prompt=prompt
-    )
-    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
-
-    knowledge_chain.return_source_documents = True
-
-    result = knowledge_chain({"query": query})
-    chatglm.history[-1][0] = query
-    return result, chatglm.history
-
-
-if __name__ == "__main__":
-    init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN)
-    vector_store = None
-    while not vector_store:
-        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
-        vector_store = init_knowledge_vector_store(filepath)
-    history = []
-    while True:
-        query = input("Input your question 请输入问题:")
-        resp, history = get_knowledge_based_answer(query=query,
-                                                   vector_store=vector_store,
-                                                   chat_history=history)
-        if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])

+ 7 - 11
chatglm_llm.py → models/chatglm_llm.py

@@ -3,8 +3,9 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
 import torch
+from configs.model_config import LLM_DEVICE
 
-DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+DEVICE = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
@@ -48,12 +49,14 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
 
-    def load_model(self, model_name_or_path: str = "THUDM/chatglm-6b"):
+    def load_model(self,
+                   model_name_or_path: str = "THUDM/chatglm-6b",
+                   llm_device=LLM_DEVICE):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
@@ -61,19 +64,12 @@
                 .half()
                 .cuda()
             )
-        elif torch.backends.mps.is_available():
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    trust_remote_code=True)
-                .float()
-                .to('mps')
-            )
         else:
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     trust_remote_code=True)
                 .float()
+                .to(llm_device)
             )
         self.model = self.model.eval()
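
With the mps auto-detection branch removed, device placement now follows the `llm_device` argument. A sketch of the resulting call, using the values from `configs/model_config.py`:

```python
from configs.model_config import LLM_DEVICE, LLM_MODEL, llm_model_dict
from models.chatglm_llm import ChatGLM

llm = ChatGLM()
# On CUDA the half-precision model is placed on the GPU; on "mps" or "cpu"
# the full-precision model is moved with .to(llm_device) by the else branch above.
llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
               llm_device=LLM_DEVICE)
```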

+ 2 - 2
webui.py

@@ -1,7 +1,7 @@
 import gradio as gr
 import os
 import shutil
-import knowledge_based_chatglm as kb
+import cli_demo as kb
 
 
 def get_file_list():
@@ -108,7 +108,7 @@
                                              value=file_list[0] if len(file_list) > 0 else None)
                 with gr.Tab("upload"):
                     file = gr.File(label="content file",
-                                   file_types=['.txt', '.md', '.docx']
+                                   file_types=['.txt', '.md', '.docx', '.pdf']
                                    ).style(height=100)
                     # 将上传的文件保存到content文件夹下,并更新下拉框
                     file.upload(upload_file,
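
Since `.pdf` is now accepted by the upload widget, uploaded files still flow through `UnstructuredFileLoader` as in `chains/local_doc_qa.py`. A quick, illustrative check that the optional PDF dependencies of `unstructured` are installed (the file path is a placeholder):

```python
from langchain.document_loaders import UnstructuredFileLoader

# Raises if unstructured's PDF extras are missing; otherwise prints the element count.
docs = UnstructuredFileLoader("content/sample.pdf", mode="elements").load()
print(f"loaded {len(docs)} elements")
```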