add stream support to cli_demo.py

imClumsyPanda · 2 years ago · commit b4aefca555

3 changed files with 91 additions and 85 deletions:
  1. chains/local_doc_qa.py (+37 -46)
  2. cli_demo.py (+12 -6)
  3. models/chatglm_llm.py (+42 -33)

chains/local_doc_qa.py (+37 -46)

@@ -2,7 +2,6 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -34,22 +33,20 @@ def load_file(filepath):
         docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
-
-def get_relevant_documents(self, query: str) -> List[Document]:
-    if self.search_type == "similarity":
-        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
-        for doc in docs:
-            doc[0].metadata["score"] = doc[1]
-        docs = [doc[0] for doc in docs]
-    elif self.search_type == "mmr":
-        docs = self.vectorstore.max_marginal_relevance_search(
-            query, **self.search_kwargs
-        )
-    else:
-        raise ValueError(f"search_type of {self.search_type} not allowed.")
-    return docs
+def generate_prompt(related_docs: List[str],
+                    query: str,
+                    prompt_template=PROMPT_TEMPLATE) -> str:
+    context = "\n".join([doc.page_content for doc in related_docs])
+    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
+    return prompt
 
 
+def get_docs_with_score(docs_with_score):
+    docs=[]
+    for doc, score in docs_with_score:
+        doc.metadata["score"] = score
+        docs.append(doc)
+    return docs
 
 class LocalDocQA:
     llm: object = None
@@ -73,8 +70,6 @@ class LocalDocQA:
 
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
-        # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
-        #                                                                    device=embedding_device)
         self.top_k = top_k
 
     def init_knowledge_vector_store(self,
@@ -134,34 +129,30 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[], ):
-        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-    如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-    
-    已知内容:
-    {context}
-    
-    问题:
-    {question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
+                                   chat_history=[],
+                                   streaming=True):
+        self.llm.streaming = streaming
         vector_store = FAISS.load_local(vs_path, self.embeddings)
-        vs_r = vector_store.as_retriever(search_type="mmr",
-                                         search_kwargs={"k": self.top_k})
-        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vs_r,
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
+        related_docs_with_score = vector_store.similarity_search_with_score(query,
+                                                                            k=self.top_k)
+        related_docs = get_docs_with_score(related_docs_with_score)
+        prompt = generate_prompt(related_docs, query)
 
-        knowledge_chain.return_source_documents = True
-        result = knowledge_chain({"query": query})
-        self.llm.history[-1][0] = query
-        return result, self.llm.history
+        if streaming:
+            for result, history in self.llm._call(prompt=prompt,
+                                                  history=chat_history):
+                history[-1] = list(history[-1])
+                history[-1][0] = query
+                response = {"query": query,
+                            "result": result,
+                            "source_documents": related_docs}
+                yield response, history
+        else:
+            result, history = self.llm._call(prompt=prompt,
+                                             history=chat_history)
+            history[-1] = list(history[-1])
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": result,
+                        "source_documents": related_docs}
+            return response, history
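
With this change, retrieval and prompt assembly go through the new module-level helpers instead of a RetrievalQA chain. A minimal sketch of that flow, assuming the helpers are imported from chains.local_doc_qa and `vector_store` is a FAISS index already loaded with the same embeddings; the query string and `k` value are illustrative:

```python
from chains.local_doc_qa import generate_prompt, get_docs_with_score

# `vector_store` is assumed to be a FAISS index, e.g. FAISS.load_local(vs_path, embeddings).
query = "示例问题"  # illustrative query
docs_with_score = vector_store.similarity_search_with_score(query, k=6)  # (Document, score) pairs
related_docs = get_docs_with_score(docs_with_score)  # copies each score into doc.metadata["score"]
prompt = generate_prompt(related_docs, query)        # fills {context} and {question} in PROMPT_TEMPLATE
```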

cli_demo.py (+12 -6)

@@ -28,10 +28,16 @@ if __name__ == "__main__":
     history = []
     while True:
         query = input("Input your question 请输入问题:")
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
+        last_print_len = 0
+        for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                     vs_path=vs_path,
+                                                                     chat_history=history,
+                                                                     streaming=True):
+            print(resp["result"][last_print_len:], end="", flush=True)
+            last_print_len = len(resp["result"])
         if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])
+            source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                           # f"""相关度:{doc.metadata['score']}\n\n"""
+                           for inum, doc in
+                           enumerate(resp["source_documents"])]
+            print("\n\n" + "\n\n".join(source_text))
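
Since get_docs_with_score writes the relevance score into each document's metadata, the source listing could also report it, as the commented-out 相关度 line above hints. A small sketch of that variant, reusing `resp` and `os` from the surrounding demo code:

```python
# Sketch: include the score stored in doc.metadata["score"] when listing sources.
for inum, doc in enumerate(resp["source_documents"]):
    filename = os.path.split(doc.metadata["source"])[-1]
    print(f"出处 [{inum + 1}] {filename}(相关度:{doc.metadata['score']:.3f}):\n\n{doc.page_content}\n")
```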

models/chatglm_llm.py (+42 -33)

@@ -5,7 +5,8 @@ from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
-
+from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
 
 DEVICE = LLM_DEVICE
@@ -54,10 +55,12 @@ class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
-    history = []
+    # history = []
     tokenizer: object = None
     model: object = None
     history_len: int = 10
+    streaming: bool = True
+    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
         super().__init__()
@@ -68,46 +71,45 @@ class ChatGLM(LLM):
 
     def _call(self,
               prompt: str,
-              stop: Optional[List[str]] = None,
-              stream=True) -> str:
-        if stream:
-            self.history = self.history + [[None, ""]]
-            for response, history in self.model.stream_chat(
-                self.tokenizer,
-                prompt,
-                history=self.history[-self.history_len:] if self.history_len > 0 else [],
-                max_length=self.max_token,
-                temperature=self.temperature,
+              history: List[List[str]] = [],
+              stop: Optional[List[str]] = None) -> str:
+        if self.streaming:
+            history = history + [[None, ""]]
+            for stream_resp, history in self.model.stream_chat(
+                    self.tokenizer,
+                    prompt,
+                    history=history[-self.history_len:] if self.history_len > 0 else [],
+                    max_length=self.max_token,
+                    temperature=self.temperature,
             ):
-                torch_gc()
-                self.history[-1][-1] = response
-                yield response
+                yield stream_resp, history
+
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
                 prompt,
-                history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
             torch_gc()
             if stop is not None:
                 response = enforce_stop_tokens(response, stop)
-            self.history = self.history + [[None, response]]
-            return response
-
-    def chat(self,
-             prompt: str) -> str:
-        response, _ = self.model.chat(
-            self.tokenizer,
-            prompt,
-            history=self.history[-self.history_len:] if self.history_len > 0 else [],
-            max_length=self.max_token,
-            temperature=self.temperature,
-        )
-        torch_gc()
-        self.history = self.history + [[None, response]]
-        return response
+            history = history + [[None, response]]
+            return response, history
+
+    # def chat(self,
+    #          prompt: str) -> str:
+    #     response, _ = self.model.chat(
+    #         self.tokenizer,
+    #         prompt,
+    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
+    #         max_length=self.max_token,
+    #         temperature=self.temperature,
+    #     )
+    #     torch_gc()
+    #     self.history = self.history + [[None, response]]
+    #     return response
 
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
@@ -149,7 +151,13 @@ class ChatGLM(LLM):
             else:
                 from accelerate import dispatch_model
 
-                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).half()
+                model = (
+                    AutoModel.from_pretrained(
+                        model_name_or_path,
+                        trust_remote_code=True,
+                        config=model_config,
+                        **kwargs)
+                    .half())
                 # 可传入device_map自定义每张卡的部署情况
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
@@ -160,7 +168,8 @@ class ChatGLM(LLM):
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     config=model_config,
-                    trust_remote_code=True)
+                    trust_remote_code=True,
+                    **kwargs)
                 .float()
                 .to(llm_device)
             )
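
After this rework, `_call` takes the conversation history explicitly and, when `streaming` is enabled, acts as a generator over ChatGLM's `stream_chat`, yielding the cumulative response together with the updated history. A minimal consumption sketch, assuming `llm` is a `ChatGLM` instance whose model and tokenizer are already loaded; the prompt is illustrative:

```python
llm.streaming = True
printed = 0
for response, history in llm._call(prompt="你好", history=[]):
    # `response` is the full answer generated so far, so print only the new suffix.
    print(response[printed:], end="", flush=True)
    printed = len(response)
print()  # finish the line once streaming completes
```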