|
@@ -9,6 +9,7 @@ import os
|
|
|
from configs.model_config import *
|
|
|
import datetime
|
|
|
from typing import List
|
|
|
+from textsplitter import ChineseTextSplitter
|
|
|
|
|
|
# return top-k text chunk from vector store
|
|
|
VECTOR_SEARCH_TOP_K = 6
|
|
@@ -17,6 +18,18 @@ VECTOR_SEARCH_TOP_K = 6
|
|
|
LLM_HISTORY_LEN = 3
|
|
|
|
|
|
|
|
|
+def load_file(filepath):
|
|
|
+ if filepath.lower().endswith(".pdf"):
|
|
|
+ loader = UnstructuredFileLoader(filepath)
|
|
|
+ textsplitter = ChineseTextSplitter(pdf=True)
|
|
|
+ docs = loader.load_and_split(textsplitter)
|
|
|
+ else:
|
|
|
+ loader = UnstructuredFileLoader(filepath, mode="elements")
|
|
|
+ textsplitter = ChineseTextSplitter(pdf=False)
|
|
|
+ docs = loader.load_and_split(text_splitter=textsplitter)
|
|
|
+ return docs
|
|
|
+
|
|
|
+
|
|
|
class LocalDocQA:
|
|
|
llm: object = None
|
|
|
embeddings: object = None
|
|
@@ -48,10 +61,10 @@ class LocalDocQA:
|
|
|
elif os.path.isfile(filepath):
|
|
|
file = os.path.split(filepath)[-1]
|
|
|
try:
|
|
|
- loader = UnstructuredFileLoader(filepath, mode="elements")
|
|
|
- docs = loader.load()
|
|
|
+ docs = load_file(filepath)
|
|
|
print(f"{file} 已成功加载")
|
|
|
- except:
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
print(f"{file} 未能成功加载")
|
|
|
return None
|
|
|
elif os.path.isdir(filepath):
|
|
@@ -59,25 +72,25 @@ class LocalDocQA:
|
|
|
for file in os.listdir(filepath):
|
|
|
fullfilepath = os.path.join(filepath, file)
|
|
|
try:
|
|
|
- loader = UnstructuredFileLoader(fullfilepath, mode="elements")
|
|
|
- docs += loader.load()
|
|
|
+ docs += load_file(fullfilepath)
|
|
|
print(f"{file} 已成功加载")
|
|
|
- except:
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
print(f"{file} 未能成功加载")
|
|
|
else:
|
|
|
docs = []
|
|
|
for file in filepath:
|
|
|
try:
|
|
|
- loader = UnstructuredFileLoader(file, mode="elements")
|
|
|
- docs += loader.load()
|
|
|
+ docs += load_file(file)
|
|
|
print(f"{file} 已成功加载")
|
|
|
- except:
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
print(f"{file} 未能成功加载")
|
|
|
|
|
|
vector_store = FAISS.from_documents(docs, self.embeddings)
|
|
|
vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
|
|
|
vector_store.save_local(vs_path)
|
|
|
- return vs_path if len(docs)>0 else None
|
|
|
+ return vs_path if len(docs) > 0 else None
|
|
|
|
|
|
def get_knowledge_based_answer(self,
|
|
|
query,
|