knowledge_based_chatglm.py

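"""Local knowledge-base Q&A with ChatGLM.

Pipeline: load local documents with UnstructuredFileLoader, embed them with a
HuggingFace/text2vec model, index the chunks in FAISS, then answer questions
through a langchain RetrievalQA chain backed by the ChatGLM LLM wrapper.
"""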
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import UnstructuredFileLoader
from chatglm_llm import ChatGLM
import sentence_transformers
import torch
import os
import readline  # enables line editing and history for the input() prompts below

# Global Parameters
EMBEDDING_MODEL = "text2vec"
VECTOR_SEARCH_TOP_K = 6
LLM_MODEL = "chatglm-6b"
LLM_HISTORY_LEN = 3
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Show reply with source text from input document
REPLY_WITH_SOURCE = True
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
}

llm_model_dict = {
    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
    "chatglm-6b": "THUDM/chatglm-6b",
}


def init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN, V_SEARCH_TOP_K=6):
    global chatglm, embeddings, VECTOR_SEARCH_TOP_K
    VECTOR_SEARCH_TOP_K = V_SEARCH_TOP_K
    chatglm = ChatGLM()
    chatglm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL])
    chatglm.history_len = LLM_HISTORY_LEN
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL])
    # Rebuild the underlying SentenceTransformer client so it runs on DEVICE
    embeddings.client = sentence_transformers.SentenceTransformer(
        embeddings.model_name, device=DEVICE)
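
# Note: the model names in the dicts above are Hugging Face hub ids, so the
# first call to init_cfg() will download the weights unless they are already
# cached locally.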


def init_knowledge_vector_store(filepath: str):
    if not os.path.exists(filepath):
        print("Path does not exist 路径不存在")
        return None
    elif os.path.isfile(filepath):
        file = os.path.split(filepath)[-1]
        try:
            loader = UnstructuredFileLoader(filepath, mode="elements")
            docs = loader.load()
            print(f"{file} loaded successfully 已成功加载")
        except Exception:
            print(f"{file} failed to load 未能成功加载")
            return None
    elif os.path.isdir(filepath):
        docs = []
        for file in os.listdir(filepath):
            fullfilepath = os.path.join(filepath, file)
            try:
                loader = UnstructuredFileLoader(fullfilepath, mode="elements")
                docs += loader.load()
                print(f"{file} loaded successfully 已成功加载")
            except Exception:
                print(f"{file} failed to load 未能成功加载")
        if not docs:
            # FAISS.from_documents fails on an empty list; bail out instead
            return None
    vector_store = FAISS.from_documents(docs, embeddings)
    return vector_store
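
# Optional (sketch): recent langchain versions can persist the index so the
# documents need not be re-embedded on every run. `faiss_index/` is an
# arbitrary example path, not part of this project:
#   vector_store.save_local("faiss_index")
#   vector_store = FAISS.load_local("faiss_index", embeddings)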


def get_knowledge_based_answer(query, vector_store, chat_history=None):
    global chatglm, embeddings
    # Avoid the mutable-default pitfall: a shared [] default would accumulate
    # history across calls once ChatGLM appends to it.
    chatglm.history = chat_history if chat_history is not None else []
    # Prompt (in Chinese, since ChatGLM is a Chinese model): "Answer the user's
    # question concisely and professionally based on the known information below.
    # If no answer can be derived from it, say '根据已知信息无法回答该问题' ('the
    # question cannot be answered from the known information') or '没有提供足够的
    # 相关信息' ('not enough relevant information was provided'); do not add
    # fabricated content to the answer, and answer in Chinese."
    prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。

已知内容:
{context}

问题:
{question}"""
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"]
    )
    knowledge_chain = RetrievalQA.from_llm(
        llm=chatglm,
        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt
    )
    # Pass retrieved chunks through verbatim instead of the default document prompt
    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
        input_variables=["page_content"], template="{page_content}"
    )
    knowledge_chain.return_source_documents = True
    result = knowledge_chain({"query": query})
    # Replace the stuffed prompt in the last history turn with the raw user query
    chatglm.history[-1][0] = query
    return result, chatglm.history
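
# With return_source_documents=True, RetrievalQA returns a dict with keys
# "query", "result" and "source_documents", which is why the REPL below can
# print either the full dict or just resp["result"].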


if __name__ == "__main__":
    init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN)
    vector_store = None
    while not vector_store:
        filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
        vector_store = init_knowledge_vector_store(filepath)
    history = []
    while True:
        query = input("Input your question 请输入问题:")
        resp, history = get_knowledge_based_answer(query=query,
                                                   vector_store=vector_store,
                                                   chat_history=history)
        if REPLY_WITH_SOURCE:
            print(resp)
        else:
            print(resp["result"])
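
# Example (sketch): one-shot programmatic use without the REPL, assuming the
# knowledge files live under a hypothetical `docs/` directory:
#   init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN)
#   store = init_knowledge_vector_store("docs/")
#   if store:
#       resp, history = get_knowledge_based_answer("什么是 ChatGLM?", store)
#       print(resp["result"])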