local_doc_qa.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
import datetime
import os
from typing import List, Union

import sentence_transformers
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS

from configs.model_config import *
from models.chatglm_llm import ChatGLM
  12. # return top-k text chunk from vector store
  13. VECTOR_SEARCH_TOP_K = 6
  14. # LLM input history length
  15. LLM_HISTORY_LEN = 3
  16. # Show reply with source text from input document
  17. REPLY_WITH_SOURCE = True
  18. class LocalDocQA:
  19. llm: object = None
  20. embeddings: object = None
  21. def init_cfg(self,
  22. embedding_model: str = EMBEDDING_MODEL,
  23. embedding_device=EMBEDDING_DEVICE,
  24. llm_history_len: int = LLM_HISTORY_LEN,
  25. llm_model: str = LLM_MODEL,
  26. llm_device=LLM_DEVICE,
  27. top_k=VECTOR_SEARCH_TOP_K,
  28. ):
  29. self.llm = ChatGLM()
  30. self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
  31. llm_device=llm_device)
  32. self.llm.history_len = llm_history_len
  33. self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
  34. self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
  35. device=embedding_device)
  36. self.top_k = top_k
  37. def init_knowledge_vector_store(self,
  38. filepath: str or List[str]):
  39. if isinstance(filepath, str):
  40. if not os.path.exists(filepath):
  41. print("路径不存在")
  42. return None
  43. elif os.path.isfile(filepath):
  44. file = os.path.split(filepath)[-1]
  45. try:
  46. loader = UnstructuredFileLoader(filepath, mode="elements")
  47. docs = loader.load()
  48. print(f"{file} 已成功加载")
  49. except:
  50. print(f"{file} 未能成功加载")
  51. return None
  52. elif os.path.isdir(filepath):
  53. docs = []
  54. for file in os.listdir(filepath):
  55. fullfilepath = os.path.join(filepath, file)
  56. try:
  57. loader = UnstructuredFileLoader(fullfilepath, mode="elements")
  58. docs += loader.load()
  59. print(f"{file} 已成功加载")
  60. except:
  61. print(f"{file} 未能成功加载")
  62. else:
  63. docs = []
  64. for file in filepath:
  65. try:
  66. loader = UnstructuredFileLoader(file, mode="elements")
  67. docs += loader.load()
  68. print(f"{file} 已成功加载")
  69. except:
  70. print(f"{file} 未能成功加载")
  71. vector_store = FAISS.from_documents(docs, self.embeddings)
  72. vs_path = f"""./vector_store/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}"""
  73. vector_store.save_local(vs_path)
  74. return vs_path if len(docs)>0 else None
  75. def get_knowledge_based_answer(self,
  76. query,
  77. vs_path,
  78. chat_history=[], ):
  79. prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
  80. 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
  81. 已知内容:
  82. {context}
  83. 问题:
  84. {question}"""
  85. prompt = PromptTemplate(
  86. template=prompt_template,
  87. input_variables=["context", "question"]
  88. )
  89. self.llm.history = chat_history
  90. vector_store = FAISS.load_local(vs_path, self.embeddings)
  91. knowledge_chain = RetrievalQA.from_llm(
  92. llm=self.llm,
  93. retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
  94. prompt=prompt
  95. )
  96. knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
  97. input_variables=["page_content"], template="{page_content}"
  98. )
  99. knowledge_chain.return_source_documents = True
  100. result = knowledge_chain({"query": query})
  101. self.llm.history[-1][0] = query
  102. return result, self.llm.history