chatglm_with_shared_memory_openai_llm.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import torch
  2. from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
  3. from langchain.llms import OpenAI
  4. from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
  5. from langchain.chains import LLMChain, RetrievalQA
  6. from langchain.embeddings.huggingface import HuggingFaceEmbeddings
  7. from langchain.prompts import PromptTemplate
  8. from langchain.text_splitter import CharacterTextSplitter
  9. from langchain.vectorstores import Chroma
  10. from langchain.document_loaders import TextLoader
  11. from models import ChatGLM
  12. import sentence_transformers
  13. import os
  14. import readline
  15. from pathlib import Path
  16. class ChatglmWithSharedMemoryOpenaiLLM:
  17. def __init__(self, params: dict = None):
  18. params = params or {}
  19. self.embedding_model = params.get('embedding_model', 'text2vec')
  20. self.vector_search_top_k = params.get('vector_search_top_k', 6)
  21. self.llm_model = params.get('llm_model', 'chatglm-6b')
  22. self.llm_history_len = params.get('llm_history_len', 10)
  23. self.device = 'cuda' if params.get('use_cuda', False) else 'cpu'
  24. self._embedding_model_dict = {
  25. "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
  26. "ernie-base": "nghuyong/ernie-3.0-base-zh",
  27. "text2vec": "GanymedeNil/text2vec-large-chinese",
  28. }
  29. self._llm_model_dict = {
  30. "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
  31. "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
  32. "chatglm-6b": "THUDM/chatglm-6b",
  33. }
  34. self.init_cfg()
  35. self.init_docsearch()
  36. self.init_state_of_history()
  37. self.summry_chain, self.memory = self.agents_answer()
  38. self.agent_chain = self.create_agent_chain()
  39. def init_cfg(self):
  40. self.chatglm = ChatGLM()
  41. self.chatglm.load_model(model_name_or_path=self._llm_model_dict[self.llm_model])
  42. self.chatglm.history_len = self.llm_history_len
  43. self.embeddings = HuggingFaceEmbeddings(model_name=self._embedding_model_dict[self.embedding_model],)
  44. self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
  45. device=self.device)
  def init_docsearch(self):
      """Index the search document and expose it as ``self.state_of_search``.

      Loads ``content/state_of_the_search.txt`` relative to the current
      working directory, splits it with CharacterTextSplitter
      (chunk_size=1000, no overlap), embeds the chunks with
      ``self.embeddings`` into the Chroma collection "state-of-search", and
      wraps the store in a "stuff" RetrievalQA chain driven by
      ``self.chatglm``.
      """
      source = Path.cwd() / "content" / "state_of_the_search.txt"
      documents = TextLoader(str(source)).load()
      splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
      chunks = splitter.split_documents(documents)
      store = Chroma.from_documents(
          chunks, self.embeddings, collection_name="state-of-search"
      )
      # Honour the configured top-k; without search_kwargs the retriever
      # falls back to langchain's default k and vector_search_top_k is ignored.
      retriever = store.as_retriever(search_kwargs={"k": self.vector_search_top_k})
      self.state_of_search = RetrievalQA.from_chain_type(
          llm=self.chatglm, chain_type="stuff", retriever=retriever
      )
  def init_state_of_history(self):
      """Index the history document and expose it as ``self.state_of_history``.

      Loads ``content/state_of_the_history.txt`` relative to the current
      working directory, splits it with CharacterTextSplitter
      (chunk_size=100, no overlap), embeds the chunks with
      ``self.embeddings`` into the Chroma collection "state-of-history", and
      wraps the store in a "stuff" RetrievalQA chain driven by
      ``self.chatglm``.
      """
      source = Path.cwd() / "content" / "state_of_the_history.txt"
      documents = TextLoader(str(source)).load()
      # Smaller chunks than the search document, as in the original setup.
      splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
      chunks = splitter.split_documents(documents)
      store = Chroma.from_documents(
          chunks, self.embeddings, collection_name="state-of-history"
      )
      # Honour the configured top-k; without search_kwargs the retriever
      # falls back to langchain's default k and vector_search_top_k is ignored.
      retriever = store.as_retriever(search_kwargs={"k": self.vector_search_top_k})
      self.state_of_history = RetrievalQA.from_chain_type(
          llm=self.chatglm, chain_type="stuff", retriever=retriever
      )
  def agents_answer(self):
      """Create the conversation-summary chain and the memory it reads.

      Returns:
          A ``(summary_chain, memory)`` pair. ``memory`` is a writable
          ConversationBufferMemory stored under the key "chat_history".
          ``summary_chain`` is a verbose LLMChain over ``self.chatglm`` whose
          prompt asks for a summary of that history for the reader named in
          ``{input}``. The chain sees the buffer only through a
          ReadOnlySharedMemory view, so running it cannot modify the history.
      """
      buffer_memory = ConversationBufferMemory(memory_key="chat_history")
      summary_prompt = PromptTemplate(
          input_variables=["input", "chat_history"],
          template=(
              "This is a conversation between a human and a bot:\n"
              "{chat_history}\n"
              "Write a summary of the conversation for {input}:\n"
          ),
      )
      summary_chain = LLMChain(
          llm=self.chatglm,
          prompt=summary_prompt,
          verbose=True,
          # Read-only view: the tool may read the shared history, never write it.
          memory=ReadOnlySharedMemory(memory=buffer_memory),
      )
      return summary_chain, buffer_memory
  def create_agent_chain(self, llm=None):
      """Assemble the zero-shot agent that routes questions to three tools.

      Tools exposed to the agent:
        * "State of Search QA System": ``self.state_of_search.run``
        * "state-of-history-qa": ``self.state_of_history.run``
        * "Summary": ``self.summry_chain.run``

      The agent prompt interpolates ``{chat_history}``, and the executor is
      given ``self.memory``.

      Args:
          llm: LLM that drives the agent's reasoning. Defaults to
              ``OpenAI(temperature=0)`` when omitted.

      Returns:
          The AgentExecutor wired to the tools and ``self.memory``.
      """
      tools = [
          Tool(
              name="State of Search QA System",
              func=self.state_of_search.run,
              # "Very useful when you need to search for related questions.
              #  The input should be a complete question."
              description="当您需要搜索有关问题时非常有用。输入应该是一个完整的问题。"
          ),
          Tool(
              name="state-of-history-qa",
              func=self.state_of_history.run,
              # "Conversation history with Lulu - the answers in here are
              #  useful when asked what happened between us."
              description="跟露露的历史对话 - 当提出我们之间发生了什么事情时,这里面的回答是很有用的"
          ),
          Tool(
              name="Summary",
              func=self.summry_chain.run,
              description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
          ),
      ]
      # "Act as a listener and answer the human's questions as well as you
      #  can; you may use these tools, they are very useful:"
      prefix = "你需要充当一个倾听者,尽量回答人类的问题,你可以使用这些工具,它们非常有用:"
      suffix = "Begin!\n{chat_history}\nQuestion: {input}\n{agent_scratchpad}"
      prompt = ZeroShotAgent.create_prompt(
          tools,
          prefix=prefix,
          suffix=suffix,
          input_variables=["input", "chat_history", "agent_scratchpad"],
      )
      if llm is None:
          llm = OpenAI(temperature=0)
      llm_chain = LLMChain(llm=llm, prompt=prompt)
      # NOTE(review): `tools=` mirrors langchain's shared-memory agent example;
      # confirm the installed ZeroShotAgent version accepts/uses this keyword.
      agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
      return AgentExecutor.from_agent_and_tools(
          agent=agent, tools=tools, verbose=True, memory=self.memory
      )