소스 검색

feat: 添加mmr相似度搜索,支持返回相似度分数

wangxinkai 2 년 전
부모
커밋
059fe82887
2개의 변경된 파일150개의 추가작업 그리고 6개의 파일을 삭제
  1. 8 6
      chains/local_doc_qa.py
  2. 142 0
      chains/test.ipynb

+ 8 - 6
chains/local_doc_qa.py

@@ -1,7 +1,9 @@
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from chains.lib.embeddings import MyEmbeddings
+# from langchain.vectorstores import FAISS
+from chains.lib.vectorstores import FAISSVS
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -50,7 +52,7 @@ class LocalDocQA:
                             use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len
 
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
+        self.embeddings = MyEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
         #                                                                    device=embedding_device)
@@ -97,12 +99,12 @@ class LocalDocQA:
                     print(f"{file} 未能成功加载")
         if len(docs) > 0:
             if vs_path and os.path.isdir(vs_path):
-                vector_store = FAISS.load_local(vs_path, self.embeddings)
+                vector_store = FAISSVS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
                 if not vs_path:
                     vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
-                vector_store = FAISS.from_documents(docs, self.embeddings)
+                vector_store = FAISSVS.from_documents(docs, self.embeddings)
 
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -127,7 +129,7 @@ class LocalDocQA:
             input_variables=["context", "question"]
         )
         self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
+        vector_store = FAISSVS.load_local(vs_path, self.embeddings)
         knowledge_chain = RetrievalQA.from_llm(
             llm=self.llm,
             retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),

파일 크기가 너무 크기때문에 변경 상태를 표시하지 않습니다.
+ 142 - 0
chains/test.ipynb


이 변경점에서 너무 많은 파일들이 변경되어 몇몇 파일들은 표시되지 않았습니다.