Browse Source

feat: 添加mmr相似度搜索,支持返回相似度分数

wangxinkai 2 years ago
parent
commit
059fe82887
2 changed files with 150 additions and 6 deletions
  1. 8 6
      chains/local_doc_qa.py
  2. 142 0
      chains/test.ipynb

+ 8 - 6
chains/local_doc_qa.py

@@ -1,7 +1,9 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from chains.lib.embeddings import MyEmbeddings
+# from langchain.vectorstores import FAISS
+from chains.lib.vectorstores import FAISSVS
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -50,7 +52,7 @@ class LocalDocQA:
                             use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len
 
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -97,12 +99,12 @@ class LocalDocQA:
                     print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISS.load_local(vs_path, self.embeddings)
+                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISS.from_documents(docs, self.embeddings)
+                vector_store = FAISSVS.from_documents(docs, self.embeddings)
 
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -127,7 +129,7 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
             retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),

File diff suppressed because it is too large
+ 142 - 0
chains/test.ipynb


Some files were not shown because too many files changed in this diff