
update webui.py and local_doc_qa.py

imClumsyPanda 2 years ago
commit 88ab9a1d21
3 changed files with 59 additions and 29 deletions
  1. chains/local_doc_qa.py  +32 -11
  2. models/chatglm_llm.py  +5 -5
  3. webui.py  +22 -13

chains/local_doc_qa.py  +32 -11

@@ -1,9 +1,8 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from chains.lib.embeddings import MyEmbeddings
-# from langchain.vectorstores import FAISS
-from chains.lib.vectorstores import FAISSVS
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -12,6 +11,7 @@ from configs.model_config import *
 import datetime
 from typing import List
 from textsplitter import ChineseTextSplitter
+from langchain.docstore.document import Document

 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -21,7 +21,10 @@ LLM_HISTORY_LEN = 3


 def load_file(filepath):
-    if filepath.lower().endswith(".pdf"):
+    if filepath.lower().endswith(".md"):
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        docs = loader.load()
+    elif filepath.lower().endswith(".pdf"):
         loader = UnstructuredFileLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True)
         docs = loader.load_and_split(textsplitter)
@@ -32,6 +35,22 @@ def load_file(filepath):
     return docs


+def get_relevant_documents(self, query: str) -> List[Document]:
+    if self.search_type == "similarity":
+        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
+        for doc in docs:
+            doc[0].metadata["score"] = doc[1]
+        docs = [doc[0] for doc in docs]
+    elif self.search_type == "mmr":
+        docs = self.vectorstore.max_marginal_relevance_search(
+            query, **self.search_kwargs
+        )
+    else:
+        raise ValueError(f"search_type of {self.search_type} not allowed.")
+    return docs
+
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
@@ -52,7 +71,7 @@ class LocalDocQA:
                             use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len

-        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -99,12 +118,12 @@ class LocalDocQA:
                     print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+                vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISSVS.from_documents(docs, self.embeddings)
+                vector_store = FAISS.from_documents(docs, self.embeddings)

             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -129,10 +148,13 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vs_r = vector_store.as_retriever(search_type="mmr",
+                                         search_kwargs={"k": self.top_k})
+        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+            retriever=vs_r,
             prompt=prompt
         )
         knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
@@ -140,7 +162,6 @@ class LocalDocQA:
         )

         knowledge_chain.return_source_documents = True
-        
         result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
         return result, self.llm.history
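Note: the new get_relevant_documents helper mirrors VectorStoreRetriever.get_relevant_documents but, for "similarity" searches, copies each hit's relevance score into doc.metadata["score"]; the line that would monkey-patch it onto VectorStoreRetriever is still commented out in this commit, so the stock retriever behavior remains in effect. A minimal sketch of how the patch would be applied once uncommented, assuming the langchain version used here and a loaded FAISS vector_store (the query string is a placeholder):

    from langchain.vectorstores.base import VectorStoreRetriever

    # Swap in the patched method: "similarity" now records scores in metadata,
    # "mmr" falls through to max_marginal_relevance_search as before.
    VectorStoreRetriever.get_relevant_documents = get_relevant_documents

    retriever = vector_store.as_retriever(search_type="similarity",
                                          search_kwargs={"k": VECTOR_SEARCH_TOP_K})
    for doc in retriever.get_relevant_documents("示例查询"):  # placeholder query
        print(doc.metadata.get("score"), doc.metadata.get("source"))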

models/chatglm_llm.py  +5 -5

@@ -72,16 +72,16 @@ class ChatGLM(LLM):
               stream=True) -> str:
         if stream:
             self.history = self.history + [[None, ""]]
-            response, _ = self.model.stream_chat(
+            for response, history in self.model.stream_chat(
                 self.tokenizer,
                 prompt,
                 history=self.history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
-            )
-            torch_gc()
-            self.history[-1][-1] = response
-            yield response
+            ):
+                torch_gc()
+                self.history[-1][-1] = response
+                yield response
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
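Note: ChatGLM's stream_chat returns a generator of (response, history) pairs, so the old single unpacking only ever captured the first chunk; iterating it turns _call into a generator that yields the cumulative response after each step. A minimal consumption sketch, assuming an already-initialized ChatGLM instance named llm (the prompt is a placeholder):

    # Each yield is the full response so far (stream_chat yields cumulative
    # text, not deltas), so print only the newly generated suffix.
    printed = ""
    for partial in llm._call("你好", stream=True):
        print(partial[len(printed):], end="", flush=True)
        printed = partial
    print()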

webui.py  +22 -13

@@ -30,19 +30,28 @@ local_doc_qa = LocalDocQA()


 def get_answer(query, vs_path, history, mode):
-    if vs_path and mode == "知识库问答":
-        resp, history = local_doc_qa.get_knowledge_based_answer(
-            query=query, vs_path=vs_path, chat_history=history)
-        source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
-{doc.page_content}
-
-<b>所属文件:</b>{doc.metadata["source"]}
-</details>""" for i, doc in enumerate(resp["source_documents"])])
-        history[-1][-1] += source
+    if mode == "知识库问答":
+        if vs_path:
+            for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=query, vs_path=vs_path, chat_history=history):
+    #         source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
+    # {doc.page_content}
+    #
+    # <b>所属文件:</b>{doc.metadata["source"]}
+    # </details>""" for i, doc in enumerate(resp["source_documents"])])
+    #         history[-1][-1] += source
+                yield history, ""
+        else:
+            history = history + [[query, ""]]
+            for resp in local_doc_qa.llm._call(query):
+                history[-1][-1] = resp + (
+                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+                yield history, ""
     else:
-        resp = local_doc_qa.llm._call(query)
-        history = history + [[query, resp + ("\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")]]
-    return history, ""
+        history = history + [[query, ""]]
+        for resp in local_doc_qa.llm._call(query):
+            history[-1][-1] = resp
+            yield history, ""


 def update_status(history, status):
@@ -62,7 +71,7 @@ def init_model():
         print(e)
         reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
         if str(e) == "Unknown platform: darwin":
-            print("报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
+            print("报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
                   " https://github.com/imClumsyPanda/langchain-ChatGLM")
         else:
             print(reply)
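Note: get_answer is now a generator, which is what lets the chatbot render partial answers as they stream in. A hedged sketch of the event wiring, with illustrative component names that may not match the rest of webui.py (Gradio runs generator callbacks as streaming handlers, but only with the queue enabled):

    import gradio as gr

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        query = gr.Textbox()
        vs_path = gr.State("")         # path to the loaded vector store, if any
        mode = gr.State("知识库问答")   # current chat mode
        # Every (history, "") yielded by get_answer re-renders the chatbot
        # and clears the textbox.
        query.submit(get_answer,
                     inputs=[query, vs_path, chatbot, mode],
                     outputs=[chatbot, query])

    demo.queue().launch()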