@@ -1,9 +1,8 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from chains.lib.embeddings import MyEmbeddings
-# from langchain.vectorstores import FAISS
-from chains.lib.vectorstores import FAISSVS
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -12,6 +11,7 @@ from configs.model_config import *
 import datetime
 from typing import List
 from textsplitter import ChineseTextSplitter
+from langchain.docstore.document import Document

 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -21,7 +21,10 @@ LLM_HISTORY_LEN = 3


 def load_file(filepath):
-    if filepath.lower().endswith(".pdf"):
+    if filepath.lower().endswith(".md"):
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        docs = loader.load()
+    elif filepath.lower().endswith(".pdf"):
         loader = UnstructuredFileLoader(filepath)
         textsplitter = ChineseTextSplitter(pdf=True)
         docs = loader.load_and_split(textsplitter)
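Markdown files now bypass ChineseTextSplitter entirely: unstructured partitions the file into element-level Documents instead. A rough sketch of the difference, using a hypothetical sample.md:

from langchain.document_loaders import UnstructuredFileLoader

# mode="elements" yields one Document per parsed element (title, list item, ...)
element_docs = UnstructuredFileLoader("sample.md", mode="elements").load()
# the default single mode returns the whole file as one Document
single_doc = UnstructuredFileLoader("sample.md").load()
print(len(element_docs), len(single_doc))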
@@ -32,6 +35,22 @@ def load_file(filepath):
     return docs


+def get_relevant_documents(self, query: str) -> List[Document]:
+    if self.search_type == "similarity":
+        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
+        for doc in docs:
+            doc[0].metadata["score"] = doc[1]
+        docs = [doc[0] for doc in docs]
+    elif self.search_type == "mmr":
+        docs = self.vectorstore.max_marginal_relevance_search(
+            query, **self.search_kwargs
+        )
+    else:
+        raise ValueError(f"search_type of {self.search_type} not allowed.")
+    return docs
+
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
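get_relevant_documents is defined at module level so it can replace VectorStoreRetriever.get_relevant_documents; in "similarity" mode it copies each hit's relevance score into doc.metadata["score"]. The activating assignment is still commented out in a later hunk, but a minimal sketch of how it would be used, assuming a loaded vector_store:

# hypothetical activation of the module-level override
VectorStoreRetriever.get_relevant_documents = get_relevant_documents
retriever = vector_store.as_retriever(search_type="similarity",
                                      search_kwargs={"k": VECTOR_SEARCH_TOP_K})
for doc in retriever.get_relevant_documents("some query"):
    print(doc.metadata.get("score"), doc.page_content[:50])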
@@ -52,7 +71,7 @@ class LocalDocQA:
                           use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len

-        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                        model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
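The custom MyEmbeddings wrapper gives way to the stock HuggingFaceEmbeddings, which builds its own sentence_transformers client internally, hence the manual client assignment staying commented out. A minimal sketch, with an illustrative model name and device standing in for the configured values:

from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# model name and device are placeholders for embedding_model_dict / embedding_device
embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese",
                                   model_kwargs={"device": "cuda"})
query_vector = embeddings.embed_query("sample query")  # -> List[float]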
@@ -99,12 +118,12 @@ class LocalDocQA:
                 print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+                vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISSVS.from_documents(docs, self.embeddings)
+                vector_store = FAISS.from_documents(docs, self.embeddings)

         vector_store.save_local(vs_path)
         return vs_path, loaded_files
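With the stock FAISS store, persistence is a plain save_local/load_local round trip. Roughly, assuming docs and embeddings as above and a placeholder path:

from langchain.vectorstores import FAISS

vector_store = FAISS.from_documents(docs, embeddings)  # build the index
vector_store.save_local("/tmp/demo_vs")                # persists index + docstore
vector_store = FAISS.load_local("/tmp/demo_vs", embeddings)
vector_store.add_documents(docs)                       # incremental update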
@@ -129,10 +148,13 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
+        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vs_r = vector_store.as_retriever(search_type="mmr",
+                                         search_kwargs={"k": self.top_k})
+        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+            retriever=vs_r,
             prompt=prompt
         )
         knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
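The retriever switches from plain similarity search to maximal marginal relevance, which re-ranks candidates so the top-k chunks are diverse as well as relevant. A sketch of the call, assuming a loaded vector_store:

retriever = vector_store.as_retriever(search_type="mmr",
                                      search_kwargs={"k": VECTOR_SEARCH_TOP_K})
docs = retriever.get_relevant_documents("sample question")  # dispatches to max_marginal_relevance_search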
@@ -140,7 +162,6 @@
         )

         knowledge_chain.return_source_documents = True
-
         result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
         return result, self.llm.history
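Because return_source_documents is set, the chain's output dict carries the retrieved chunks alongside the generated answer; roughly:

result = knowledge_chain({"query": "sample question"})  # placeholder query
print(result["result"])                  # the generated answer
for doc in result["source_documents"]:   # chunks that were fed to the LLM
    print(doc.metadata, doc.page_content[:80])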
|