소스 검색

Merge remote-tracking branch 'origin/dev' into dev

# Conflicts:
#	configs/model_config.py
glide-the 2 년 전
부모
커밋
a0b312d749
10개의 변경된 파일170개의 추가작업 그리고 111개의 파일을 삭제
  1. 7 1
      README.md
  2. 91 55
      chains/local_doc_qa.py
  3. 6 3
      cli_demo.py
  4. 5 6
      configs/model_config.py
  5. BIN
      img/langchain+chatglm2.png
  6. BIN
      img/qr_code_10.jpg
  7. BIN
      img/qr_code_9.jpg
  8. 28 24
      models/chatglm_llm.py
  9. 11 0
      utils/__init__.py
  10. 22 22
      webui.py

+ 7 - 1
README.md

@@ -14,8 +14,14 @@
 
 ![实现原理图](img/langchain+chatglm.png)
 
+从文档处理角度来看,实现流程如下:
+
+![实现原理图2](img/langchain+chatglm2.png)
+
 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
 
+🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
+
 📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
 
 ## 变更日志
@@ -166,6 +172,6 @@ Web UI 可以实现如下功能:
   - [ ] 实现调用 API 的 Web UI Demo
 
 ## 项目交流群
-![二维码](img/qr_code_9.jpg)
+![二维码](img/qr_code_10.jpg)
 
 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。

+ 91 - 55
chains/local_doc_qa.py

@@ -8,6 +8,7 @@ from textsplitter import ChineseTextSplitter
 from typing import List, Tuple
 from langchain.docstore.document import Document
 import numpy as np
+from utils import torch_gc
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -15,6 +16,10 @@ VECTOR_SEARCH_TOP_K = 6
 # LLM input history length
 LLM_HISTORY_LEN = 3
 
+DEVICE_ = EMBEDDING_DEVICE
+DEVICE_ID = "0" if torch.cuda.is_available() else None
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
+
 
 def load_file(filepath):
     if filepath.lower().endswith(".md"):
@@ -30,6 +35,7 @@ def load_file(filepath):
         docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
+
 def generate_prompt(related_docs: List[str],
                     query: str,
                     prompt_template=PROMPT_TEMPLATE) -> str:
@@ -39,7 +45,7 @@ def generate_prompt(related_docs: List[str],
 
 
 def get_docs_with_score(docs_with_score):
-    docs=[]
+    docs = []
     for doc, score in docs_with_score:
         doc.metadata["score"] = score
         docs.append(doc)
@@ -50,7 +56,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     lists = []
     ls1 = [ls[0]]
     for i in range(1, len(ls)):
-        if ls[i-1] + 1 == ls[i]:
+        if ls[i - 1] + 1 == ls[i]:
             ls1.append(ls[i])
         else:
             lists.append(ls1)
@@ -59,49 +65,52 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
     return lists
 
 
-
 def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
-    ) -> List[Tuple[Document, float]]:
-        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
-        docs = []
-        id_set = set()
-        for j, i in enumerate(indices[0]):
-            if i == -1:
-                # This happens when not enough docs are returned.
-                continue
-            _id = self.index_to_docstore_id[i]
-            doc = self.docstore.search(_id)
-            id_set.add(i)
-            docs_len = len(doc.page_content)
-            for k in range(1, max(i, len(docs)-i)):
-                for l in [i+k, i-k]:
-                    if 0 <= l < len(self.index_to_docstore_id):
-                        _id0 = self.index_to_docstore_id[l]
-                        doc0 = self.docstore.search(_id0)
-                        if docs_len + len(doc0.page_content) > self.chunk_size:
-                            break
-                        elif doc0.metadata["source"] == doc.metadata["source"]:
-                            docs_len += len(doc0.page_content)
-                            id_set.add(l)
-        id_list = sorted(list(id_set))
-        id_lists = seperate_list(id_list)
-        for id_seq in id_lists:
-            for id in id_seq:
-                if id == id_seq[0]:
-                    _id = self.index_to_docstore_id[id]
-                    doc = self.docstore.search(_id)
-                else:
-                    _id0 = self.index_to_docstore_id[id]
+) -> List[Tuple[Document, float]]:
+    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+    docs = []
+    id_set = set()
+    for j, i in enumerate(indices[0]):
+        if i == -1:
+            # This happens when not enough docs are returned.
+            continue
+        _id = self.index_to_docstore_id[i]
+        doc = self.docstore.search(_id)
+        id_set.add(i)
+        docs_len = len(doc.page_content)
+        for k in range(1, max(i, len(docs) - i)):
+            break_flag = False
+            for l in [i + k, i - k]:
+                if 0 <= l < len(self.index_to_docstore_id):
+                    _id0 = self.index_to_docstore_id[l]
                     doc0 = self.docstore.search(_id0)
-                    doc.page_content += doc0.page_content
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            docs.append((doc, scores[0][j]))
-        return docs
-
+                    if docs_len + len(doc0.page_content) > self.chunk_size:
+                        break_flag=True
+                        break
+                    elif doc0.metadata["source"] == doc.metadata["source"]:
+                        docs_len += len(doc0.page_content)
+                        id_set.add(l)
+            if break_flag:
+                break
+    id_list = sorted(list(id_set))
+    id_lists = seperate_list(id_list)
+    for id_seq in id_lists:
+        for id in id_seq:
+            if id == id_seq[0]:
+                _id = self.index_to_docstore_id[id]
+                doc = self.docstore.search(_id)
+            else:
+                _id0 = self.index_to_docstore_id[id]
+                doc0 = self.docstore.search(_id0)
+                doc.page_content += doc0.page_content
+        if not isinstance(doc, Document):
+            raise ValueError(f"Could not find document for id {_id}, got {doc}")
+        docs.append((doc, scores[0][j]))
+    torch_gc(DEVICE)
+    return docs
 
 
 class LocalDocQA:
@@ -172,10 +181,12 @@ class LocalDocQA:
             if vs_path and os.path.isdir(vs_path):
                 vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
+                torch_gc(DEVICE)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
                 vector_store = FAISS.from_documents(docs, self.embeddings)
+                torch_gc(DEVICE)
 
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -187,29 +198,54 @@ class LocalDocQA:
                                    query,
                                    vs_path,
                                    chat_history=[],
-                                   streaming=True):
-        self.llm.streaming = streaming
+                                   streaming: bool = STREAMING):
         vector_store = FAISS.load_local(vs_path, self.embeddings)
         FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
-        vector_store.chunk_size=self.chunk_size
+        vector_store.chunk_size = self.chunk_size
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
         prompt = generate_prompt(related_docs, query)
 
-        if streaming:
-            for result, history in self.llm._call(prompt=prompt,
-                                                  history=chat_history):
-                history[-1][0] = query
-                response = {"query": query,
-                            "result": result,
-                            "source_documents": related_docs}
-                yield response, history
-        else:
-            result, history = self.llm._call(prompt=prompt,
-                                             history=chat_history)
+        # if streaming:
+        #     for result, history in self.llm._stream_call(prompt=prompt,
+        #                                                  history=chat_history):
+        #         history[-1][0] = query
+        #         response = {"query": query,
+        #                     "result": result,
+        #                     "source_documents": related_docs}
+        #         yield response, history
+        # else:
+        for result, history in self.llm._call(prompt=prompt,
+                                              history=chat_history,
+                                              streaming=streaming):
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
-            return response, history
+            yield response, history
+
+
+if __name__ == "__main__":
+    local_doc_qa = LocalDocQA()
+    local_doc_qa.init_cfg()
+    query = "本项目使用的embedding模型是什么,消耗多少显存"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
+    last_print_len = 0
+    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                 vs_path=vs_path,
+                                                                 chat_history=[],
+                                                                 streaming=True):
+        print(resp["result"][last_print_len:], end="", flush=True)
+        last_print_len = len(resp["result"])
+    source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                   # f"""相关度:{doc.metadata['score']}\n\n"""
+                   for inum, doc in
+                   enumerate(resp["source_documents"])]
+    print("\n\n" + "\n\n".join(source_text))
+    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+    #                                                              vs_path=vs_path,
+    #                                                              chat_history=[],
+    #                                                              streaming=False):
+    #     print(resp["result"])
+    pass

+ 6 - 3
cli_demo.py

@@ -32,9 +32,12 @@ if __name__ == "__main__":
         for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                      vs_path=vs_path,
                                                                      chat_history=history,
-                                                                     streaming=True):
-            print(resp["result"][last_print_len:], end="", flush=True)
-            last_print_len = len(resp["result"])
+                                                                     streaming=STREAMING):
+            if STREAMING:
+                print(resp["result"][last_print_len:], end="", flush=True)
+                last_print_len = len(resp["result"])
+            else:
+                print(resp["result"])
         if REPLY_WITH_SOURCE:
             source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                            # f"""相关度:{doc.metadata['score']}\n\n"""

+ 5 - 6
configs/model_config.py

@@ -27,6 +27,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM streaming reponse
+STREAMING = True
+
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
 
@@ -38,14 +41,10 @@ VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_
 UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "")
 
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
-PROMPT_TEMPLATE = """已知信息在下方"="包裹的段落,基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。 
-
-====================================已知信息===================================================== 
+PROMPT_TEMPLATE = """已知信息:
 {context} 
-================================================================================================
 
-问题:"{question}"
-答案:"""
+根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
 
 # 匹配后单段上下文长度
 CHUNK_SIZE = 500

BIN
img/langchain+chatglm2.png


BIN
img/qr_code_10.jpg


BIN
img/qr_code_9.jpg


+ 28 - 24
models/chatglm_llm.py

@@ -4,21 +4,15 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
-from configs.model_config import LLM_DEVICE
+from configs.model_config import *
 from langchain.callbacks.base import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
+from utils import torch_gc
 
-DEVICE = LLM_DEVICE
+DEVICE_ = LLM_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
-
-
-def torch_gc():
-    if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 
 def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
@@ -59,7 +53,6 @@ class ChatGLM(LLM):
     tokenizer: object = None
     model: object = None
     history_len: int = 10
-    streaming: bool = True
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
@@ -72,8 +65,8 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
               history: List[List[str]] = [],
-              stop: Optional[List[str]] = None) -> str:
-        if self.streaming:
+              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+        if streaming:
             for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                     self.tokenizer,
                     prompt,
@@ -81,25 +74,23 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
+                torch_gc(DEVICE)
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
-
         else:
             response, _ = self.model.chat(
-                self.tokenizer,
-                prompt,
-                history=history[-self.history_len:] if self.history_len > 0 else [],
-                max_length=self.max_token,
-                temperature=self.temperature,
+                    self.tokenizer,
+                    prompt,
+                    history=history[-self.history_len:] if self.history_len > 0 else [],
+                    max_length=self.max_token,
+                    temperature=self.temperature,
             )
-            torch_gc()
-            if stop is not None:
-                response = enforce_stop_tokens(response, stop)
-            history = history + [[None, response]]
-            return response, history
+            torch_gc(DEVICE)
+            history += [[prompt, response]]
+            yield response, history
 
     # def chat(self,
     #          prompt: str) -> str:
@@ -191,3 +182,16 @@ class ChatGLM(LLM):
                 print("加载PrefixEncoder模型参数失败")
 
         self.model = self.model.eval()
+
+
+if __name__ == "__main__":
+    llm = ChatGLM()
+    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
+                   llm_device=LLM_DEVICE, )
+    last_print_len=0
+    for resp, history in llm._call("你好", streaming=True):
+        print(resp[last_print_len:], end="", flush=True)
+        last_print_len = len(resp)
+    for resp, history in llm._call("你好", streaming=False):
+        print(resp)
+    pass

+ 11 - 0
utils/__init__.py

@@ -0,0 +1,11 @@
+import torch.cuda
+import torch.mps
+import torch.backends
+
+def torch_gc(DEVICE):
+    if torch.cuda.is_available():
+        with torch.cuda.device(DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()

+ 22 - 22
webui.py

@@ -29,28 +29,28 @@ llm_model_dict_list = list(llm_model_dict.keys())
 local_doc_qa = LocalDocQA()
 
 
-def get_answer(query, vs_path, history, mode):
-    if mode == "知识库问答":
-        if vs_path:
-            for resp, history in local_doc_qa.get_knowledge_based_answer(
-                    query=query, vs_path=vs_path, chat_history=history):
-                source = "\n\n"
-                source += "".join(
-                    [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
-                     f"""{doc.page_content}\n"""
-                     f"""</details>"""
-                     for i, doc in
-                     enumerate(resp["source_documents"])])
-                history[-1][-1] += source
-                yield history, ""
-        else:
-            for resp, history in local_doc_qa.llm._call(query, history):
-                history[-1][-1] = resp + (
-                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-                yield history, ""
+def get_answer(query, vs_path, history, mode,
+               streaming: bool = STREAMING):
+    if mode == "知识库问答" and vs_path:
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=query,
+                vs_path=vs_path,
+                chat_history=history,
+                streaming=streaming):
+            source = "\n\n"
+            source += "".join(
+                [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                 f"""{doc.page_content}\n"""
+                 f"""</details>"""
+                 for i, doc in
+                 enumerate(resp["source_documents"])])
+            history[-1][-1] += source
+            yield history, ""
     else:
-        for resp, history in local_doc_qa.llm._call(query, history):
-            history[-1][-1] = resp
+        for resp, history in local_doc_qa.llm._call(query, history,
+                                                    streaming=streaming):
+            history[-1][-1] = resp + (
+                "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
             yield history, ""
 
 
@@ -84,7 +84,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, to
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              top_k=top_k)
+                              top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e: