
update local_doc_qa.py

imClumsyPanda · 2 years ago
commit 8ae84c6c93
5 files changed, 51 insertions(+), 11 deletions(-)
  1. chains/local_doc_qa.py (+23, -10)
  2. cli_demo.py (+1, -1)
  3. content/1.pdf (binary)
  4. textsplitter/__init__.py (+2, -0)
  5. textsplitter/chinese_text_splitter.py (+25, -0)

+ 23 - 10
chains/local_doc_qa.py

@@ -9,6 +9,7 @@ import os
 from configs.model_config import *
 import datetime
 from typing import List
+from textsplitter import ChineseTextSplitter
 
 # return top-k text chunk from vector store
 VECTOR_SEARCH_TOP_K = 6
@@ -17,6 +18,18 @@ VECTOR_SEARCH_TOP_K = 6
 LLM_HISTORY_LEN = 3
 
 
+def load_file(filepath):
+    if filepath.lower().endswith(".pdf"):
+        loader = UnstructuredFileLoader(filepath)
+        textsplitter = ChineseTextSplitter(pdf=True)
+        docs = loader.load_and_split(textsplitter)
+    else:
+        loader = UnstructuredFileLoader(filepath, mode="elements")
+        textsplitter = ChineseTextSplitter(pdf=False)
+        docs = loader.load_and_split(text_splitter=textsplitter)
+    return docs
+
+
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
@@ -48,10 +61,10 @@ class LocalDocQA:
             elif os.path.isfile(filepath):
                 file = os.path.split(filepath)[-1]
                 try:
-                    loader = UnstructuredFileLoader(filepath, mode="elements")
-                    docs = loader.load()
+                    docs = load_file(filepath)
                     print(f"{file} 已成功加载")
-                except:
+                except Exception as e:
+                    print(e)
                     print(f"{file} 未能成功加载")
                     return None
             elif os.path.isdir(filepath):
@@ -59,25 +72,25 @@ class LocalDocQA:
                 for file in os.listdir(filepath):
                     fullfilepath = os.path.join(filepath, file)
                     try:
-                        loader = UnstructuredFileLoader(fullfilepath, mode="elements")
-                        docs += loader.load()
+                        docs += load_file(fullfilepath)
                         print(f"{file} 已成功加载")
-                    except:
+                    except Exception as e:
+                        print(e)
                         print(f"{file} 未能成功加载")
         else:
             docs = []
             for file in filepath:
                 try:
-                    loader = UnstructuredFileLoader(file, mode="elements")
-                    docs += loader.load()
+                    docs += load_file(file)
                     print(f"{file} 已成功加载")
-                except:
+                except Exception as e:
+                    print(e)
                     print(f"{file} 未能成功加载")
 
         vector_store = FAISS.from_documents(docs, self.embeddings)
         vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
         vector_store.save_local(vs_path)
-        return vs_path if len(docs)>0 else None
+        return vs_path if len(docs) > 0 else None
 
     def get_knowledge_based_answer(self,
                                    query,

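For context: every load path in this file now goes through the new load_file() helper, so .pdf inputs get ChineseTextSplitter(pdf=True) while other files keep the previous mode="elements" loader, and failures now print the underlying exception instead of being silently swallowed. A minimal usage sketch (the .txt path below is an illustrative assumption; only content/1.pdf ships with this commit):

    # Hypothetical usage of load_file() from chains/local_doc_qa.py.
    from chains.local_doc_qa import load_file

    pdf_docs = load_file("content/1.pdf")     # PDF branch: ChineseTextSplitter(pdf=True)
    txt_docs = load_file("docs/example.txt")  # default branch: mode="elements", pdf=False
    print(f"loaded {len(pdf_docs)} + {len(txt_docs)} document chunks")
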
+ 1 - 1
cli_demo.py

@@ -2,7 +2,7 @@ from configs.model_config import *
 from chains.local_doc_qa import LocalDocQA
 
 # return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 10
+VECTOR_SEARCH_TOP_K = 6
 
 # LLM input history length
 LLM_HISTORY_LEN = 3

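This brings the CLI demo's default back in line with the VECTOR_SEARCH_TOP_K = 6 constant defined in chains/local_doc_qa.py above.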
Binary
content/1.pdf


+ 2 - 0
textsplitter/__init__.py

@@ -0,0 +1,2 @@
+
+from .chinese_text_splitter import *

+ 25 - 0
textsplitter/chinese_text_splitter.py

@@ -0,0 +1,25 @@
+from langchain.text_splitter import CharacterTextSplitter
+import re
+from typing import List
+
+
+class ChineseTextSplitter(CharacterTextSplitter):
+    def __init__(self, pdf: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.pdf = pdf
+
+    def split_text(self, text: str) -> List[str]:
+        if self.pdf:
+            text = re.sub(r"\n{3,}", "\n", text)
+            text = re.sub('\s', ' ', text)
+            text = text.replace("\n\n", "")
+        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
+        sent_list = []
+        for ele in sent_sep_pattern.split(text):
+            if sent_sep_pattern.match(ele) and sent_list:
+                sent_list[-1] += ele
+            elif ele:
+                sent_list.append(ele)
+        return sent_list
+
+
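A quick sketch of the new splitter's behavior, assuming the repo root is on PYTHONPATH; the sample sentence and expected output are illustrative (traced from sent_sep_pattern above), not part of the commit:

    # Illustrative check of ChineseTextSplitter.split_text().
    from textsplitter import ChineseTextSplitter

    splitter = ChineseTextSplitter(pdf=False)
    # The regex keeps each sentence-ending mark (。!? etc.) attached to its sentence.
    print(splitter.split_text("你好。明天见!"))
    # expected: ['你好。', '明天见!']

With pdf=True, runs of three or more newlines are collapsed and remaining whitespace is normalized to spaces before splitting, which compensates for the hard line wrapping typical of PDF-extracted text.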