Browse Source

update local_doc_qa.py

imClumsyPanda 2 năm trước cách đây
mục cha
commit
a4e67a67b4
1 tập tin đã thay đổi với 16 bổ sung7 xóa
  1. 16 7
      chains/local_doc_qa.py

+ 16 - 7
chains/local_doc_qa.py

@@ -82,15 +82,19 @@ def similarity_search_with_score_by_vector(
         id_set.add(i)
         docs_len = len(doc.page_content)
         for k in range(1, max(i, len(docs) - i)):
+            break_flag = False
             for l in [i + k, i - k]:
                 if 0 <= l < len(self.index_to_docstore_id):
                     _id0 = self.index_to_docstore_id[l]
                     doc0 = self.docstore.search(_id0)
                     if docs_len + len(doc0.page_content) > self.chunk_size:
+                        break_flag=True
                         break
                     elif doc0.metadata["source"] == doc.metadata["source"]:
                         docs_len += len(doc0.page_content)
                         id_set.add(l)
+            if break_flag:
+                break
     id_list = sorted(list(id_set))
     id_lists = seperate_list(id_list)
     for id_seq in id_lists:
@@ -225,8 +229,8 @@ class LocalDocQA:
 if __name__ == "__main__":
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg()
-    query = "你好"
-    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123"
+    query = "本项目使用的embedding模型是什么,消耗多少显存"
+    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
     last_print_len = 0
     for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                  vs_path=vs_path,
@@ -234,9 +238,14 @@ if __name__ == "__main__":
                                                                  streaming=True):
         print(resp["result"][last_print_len:], end="", flush=True)
         last_print_len = len(resp["result"])
-    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                 vs_path=vs_path,
-                                                                 chat_history=[],
-                                                                 streaming=False):
-        print(resp["result"])
+    source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                   # f"""相关度:{doc.metadata['score']}\n\n"""
+                   for inum, doc in
+                   enumerate(resp["source_documents"])]
+    print("\n\n" + "\n\n".join(source_text))
+    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+    #                                                              vs_path=vs_path,
+    #                                                              chat_history=[],
+    #                                                              streaming=False):
+    #     print(resp["result"])
     pass