# local_doc_qa.py
import datetime
import os
from typing import List, Tuple

import numpy as np
from langchain.docstore.document import Document
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

from configs.model_config import *
from models.chatglm_llm import ChatGLM
from textsplitter import ChineseTextSplitter
# Number of top-scoring text chunks returned from the vector store per query.
VECTOR_SEARCH_TOP_K = 6
# Number of past conversation rounds kept in the LLM's input history.
LLM_HISTORY_LEN = 3
  15. def load_file(filepath):
  16. if filepath.lower().endswith(".md"):
  17. loader = UnstructuredFileLoader(filepath, mode="elements")
  18. docs = loader.load()
  19. elif filepath.lower().endswith(".pdf"):
  20. loader = UnstructuredFileLoader(filepath)
  21. textsplitter = ChineseTextSplitter(pdf=True)
  22. docs = loader.load_and_split(textsplitter)
  23. else:
  24. loader = UnstructuredFileLoader(filepath, mode="elements")
  25. textsplitter = ChineseTextSplitter(pdf=False)
  26. docs = loader.load_and_split(text_splitter=textsplitter)
  27. return docs
  28. def generate_prompt(related_docs: List[str],
  29. query: str,
  30. prompt_template=PROMPT_TEMPLATE) -> str:
  31. context = "\n".join([doc.page_content for doc in related_docs])
  32. prompt = prompt_template.replace("{question}", query).replace("{context}", context)
  33. return prompt
  34. def get_docs_with_score(docs_with_score):
  35. docs=[]
  36. for doc, score in docs_with_score:
  37. doc.metadata["score"] = score
  38. docs.append(doc)
  39. return docs
  40. def seperate_list(ls: List[int]) -> List[List[int]]:
  41. lists = []
  42. ls1 = [ls[0]]
  43. for i in range(1, len(ls)):
  44. if ls[i-1] + 1 == ls[i]:
  45. ls1.append(ls[i])
  46. else:
  47. lists.append(ls1)
  48. ls1 = [ls[i]]
  49. lists.append(ls1)
  50. return lists
  51. def similarity_search_with_score_by_vector(
  52. self,
  53. embedding: List[float],
  54. k: int = 4,
  55. ) -> List[Tuple[Document, float]]:
  56. scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
  57. docs = []
  58. id_set = set()
  59. for j, i in enumerate(indices[0]):
  60. if i == -1:
  61. # This happens when not enough docs are returned.
  62. continue
  63. _id = self.index_to_docstore_id[i]
  64. doc = self.docstore.search(_id)
  65. id_set.add(i)
  66. docs_len = len(doc.page_content)
  67. for k in range(1, max(i, len(docs)-i)):
  68. for l in [i+k, i-k]:
  69. if 0 <= l < len(self.index_to_docstore_id):
  70. _id0 = self.index_to_docstore_id[l]
  71. doc0 = self.docstore.search(_id0)
  72. if docs_len + len(doc0.page_content) > self.chunk_size:
  73. break
  74. elif doc0.metadata["source"] == doc.metadata["source"]:
  75. docs_len += len(doc0.page_content)
  76. id_set.add(l)
  77. id_list = sorted(list(id_set))
  78. id_lists = seperate_list(id_list)
  79. for id_seq in id_lists:
  80. for id in id_seq:
  81. if id == id_seq[0]:
  82. _id = self.index_to_docstore_id[id]
  83. doc = self.docstore.search(_id)
  84. else:
  85. _id0 = self.index_to_docstore_id[id]
  86. doc0 = self.docstore.search(_id0)
  87. doc.page_content += doc0.page_content
  88. if not isinstance(doc, Document):
  89. raise ValueError(f"Could not find document for id {_id}, got {doc}")
  90. docs.append((doc, scores[0][j]))
  91. return docs
  92. class LocalDocQA:
  93. llm: object = None
  94. embeddings: object = None
  95. top_k: int = VECTOR_SEARCH_TOP_K
  96. chunk_size: int = CHUNK_SIZE
  97. def init_cfg(self,
  98. embedding_model: str = EMBEDDING_MODEL,
  99. embedding_device=EMBEDDING_DEVICE,
  100. llm_history_len: int = LLM_HISTORY_LEN,
  101. llm_model: str = LLM_MODEL,
  102. llm_device=LLM_DEVICE,
  103. top_k=VECTOR_SEARCH_TOP_K,
  104. use_ptuning_v2: bool = USE_PTUNING_V2
  105. ):
  106. self.llm = ChatGLM()
  107. self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
  108. llm_device=llm_device,
  109. use_ptuning_v2=use_ptuning_v2)
  110. self.llm.history_len = llm_history_len
  111. self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
  112. model_kwargs={'device': embedding_device})
  113. self.top_k = top_k
  114. def init_knowledge_vector_store(self,
  115. filepath: str or List[str],
  116. vs_path: str or os.PathLike = None):
  117. loaded_files = []
  118. if isinstance(filepath, str):
  119. if not os.path.exists(filepath):
  120. print("路径不存在")
  121. return None
  122. elif os.path.isfile(filepath):
  123. file = os.path.split(filepath)[-1]
  124. try:
  125. docs = load_file(filepath)
  126. print(f"{file} 已成功加载")
  127. loaded_files.append(filepath)
  128. except Exception as e:
  129. print(e)
  130. print(f"{file} 未能成功加载")
  131. return None
  132. elif os.path.isdir(filepath):
  133. docs = []
  134. for file in os.listdir(filepath):
  135. fullfilepath = os.path.join(filepath, file)
  136. try:
  137. docs += load_file(fullfilepath)
  138. print(f"{file} 已成功加载")
  139. loaded_files.append(fullfilepath)
  140. except Exception as e:
  141. print(e)
  142. print(f"{file} 未能成功加载")
  143. else:
  144. docs = []
  145. for file in filepath:
  146. try:
  147. docs += load_file(file)
  148. print(f"{file} 已成功加载")
  149. loaded_files.append(file)
  150. except Exception as e:
  151. print(e)
  152. print(f"{file} 未能成功加载")
  153. if len(docs) > 0:
  154. if vs_path and os.path.isdir(vs_path):
  155. vector_store = FAISS.load_local(vs_path, self.embeddings)
  156. vector_store.add_documents(docs)
  157. else:
  158. if not vs_path:
  159. vs_path = f"""{VS_ROOT_PATH}{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
  160. vector_store = FAISS.from_documents(docs, self.embeddings)
  161. vector_store.save_local(vs_path)
  162. return vs_path, loaded_files
  163. else:
  164. print("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
  165. return None, loaded_files
  166. def get_knowledge_based_answer(self,
  167. query,
  168. vs_path,
  169. chat_history=[],
  170. streaming=True):
  171. self.llm.streaming = streaming
  172. vector_store = FAISS.load_local(vs_path, self.embeddings)
  173. FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
  174. vector_store.chunk_size=self.chunk_size
  175. related_docs_with_score = vector_store.similarity_search_with_score(query,
  176. k=self.top_k)
  177. related_docs = get_docs_with_score(related_docs_with_score)
  178. prompt = generate_prompt(related_docs, query)
  179. if streaming:
  180. for result, history in self.llm._call(prompt=prompt,
  181. history=chat_history):
  182. history[-1][0] = query
  183. response = {"query": query,
  184. "result": result,
  185. "source_documents": related_docs}
  186. yield response, history
  187. else:
  188. result, history = self.llm._call(prompt=prompt,
  189. history=chat_history)
  190. history[-1][0] = query
  191. response = {"query": query,
  192. "result": result,
  193. "source_documents": related_docs}
  194. return response, history