@@ -1,7 +1,9 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from chains.lib.embeddings import MyEmbeddings
+# from langchain.vectorstores import FAISS
+from chains.lib.vectorstores import FAISSVS
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -50,7 +52,7 @@ class LocalDocQA:
                             use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len
 
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -97,12 +99,12 @@ class LocalDocQA:
print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISS.load_local(vs_path, self.embeddings)
+                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISS.from_documents(docs, self.embeddings)
+                vector_store = FAISSVS.from_documents(docs, self.embeddings)
 
             vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -127,7 +129,7 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
             retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
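
The patch shows only the call sites; the implementations of `MyEmbeddings` and `FAISSVS` live in `chains/lib/` and are not part of this diff. A minimal sketch of what a drop-in `MyEmbeddings` could look like, assuming (it is not confirmed by this patch) that it subclasses the `HuggingFaceEmbeddings` it replaces, keeps the parent constructor, and only changes how documents are encoded:

```python
# Hypothetical sketch of chains/lib/embeddings.py; the real module is not
# shown in this diff. Assumes MyEmbeddings stays constructor-compatible with
# HuggingFaceEmbeddings so the LocalDocQA call site above needs no changes.
from typing import List

from langchain.embeddings.huggingface import HuggingFaceEmbeddings


class MyEmbeddings(HuggingFaceEmbeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # self.client is the underlying sentence_transformers.SentenceTransformer
        # (see the commented-out lines in the second hunk). Normalizing the
        # vectors makes inner-product similarity behave like cosine similarity.
        embeddings = self.client.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()
```

Keeping the parent's constructor signature is what lets the call site keep passing `model_name` and `model_kwargs` unchanged.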
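`FAISSVS` is likewise referenced only through `from_documents`, `load_local`, `add_documents`, and `save_local`, all of which a thin subclass inherits from LangChain's `FAISS`, so every call site in this diff keeps working. A sketch under that assumption, with a hypothetical score-threshold override as the kind of behavior such a subclass might add:

```python
# Hypothetical sketch of chains/lib/vectorstores.py; the real module is not
# shown in this diff. Assumes FAISSVS only adds behavior and leaves the
# classmethods used above (from_documents, load_local) untouched.
from typing import List, Optional, Tuple

from langchain.docstore.document import Document
from langchain.vectorstores import FAISS


class FAISSVS(FAISS):
    def similarity_search_with_score(
        self, query: str, k: int = 4, score_threshold: Optional[float] = None
    ) -> List[Tuple[Document, float]]:
        # Delegate to the stock FAISS search, then optionally drop hits whose
        # L2 distance exceeds the threshold (illustrative extension only).
        docs_and_scores = super().similarity_search_with_score(query, k=k)
        if score_threshold is None:
            return docs_and_scores
        return [(doc, score) for doc, score in docs_and_scores
                if score <= score_threshold]
```

Because both classes preserve their parents' interfaces, the two-line import swap in the first hunk is the only structural change the rest of the file needs.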