import torch
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.chains import LLMChain, RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import TextLoader
from models import ChatGLM
import sentence_transformers
import os
import readline
from pathlib import Path
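
# NOTE: `ChatGLM` comes from a project-local module that is not shown here.
# The calls below assume it is a LangChain-compatible custom LLM, roughly:
#
#     class ChatGLM(LLM):                  # langchain.llms.base.LLM subclass
#         history_len: int = 10
#         def load_model(self, model_name_or_path: str): ...
#         def _call(self, prompt: str, stop=None) -> str: ...
#
# This sketch is an assumption about the interface, not the actual module.
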
class ChatglmWithSharedMemoryOpenaiLLM:

    def __init__(self, params: dict = None):
        params = params or {}
        self.embedding_model = params.get('embedding_model', 'text2vec')
        self.vector_search_top_k = params.get('vector_search_top_k', 6)  # reserved; not used below
        self.llm_model = params.get('llm_model', 'chatglm-6b')
        self.llm_history_len = params.get('llm_history_len', 10)
        self.device = 'cuda' if params.get('use_cuda', False) else 'cpu'
        self._embedding_model_dict = {
            "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
            "ernie-base": "nghuyong/ernie-3.0-base-zh",
            "text2vec": "GanymedeNil/text2vec-large-chinese",
        }
        self._llm_model_dict = {
            "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
            "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
            "chatglm-6b": "THUDM/chatglm-6b",
        }
        self.init_cfg()
        self.init_docsearch()
        self.init_state_of_history()
        self.summary_chain, self.memory = self.agents_answer()
        self.agent_chain = self.create_agent_chain()
    def init_cfg(self):
        # Load the local ChatGLM model and cap how many turns of history it keeps.
        self.chatglm = ChatGLM()
        self.chatglm.load_model(model_name_or_path=self._llm_model_dict[self.llm_model])
        self.chatglm.history_len = self.llm_history_len
        # Build the embeddings, then replace the client with a SentenceTransformer
        # pinned to the requested device.
        self.embeddings = HuggingFaceEmbeddings(model_name=self._embedding_model_dict[self.embedding_model])
        self.embeddings.client = sentence_transformers.SentenceTransformer(
            self.embeddings.model_name, device=self.device)
    def init_docsearch(self):
        # Index the "state of search" document into Chroma and wrap it
        # in a retrieval QA chain answered by ChatGLM.
        doc_path = str(Path.cwd() / "content/state_of_the_search.txt")
        loader = TextLoader(doc_path)
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)
        docsearch = Chroma.from_documents(texts, self.embeddings, collection_name="state-of-search")
        self.state_of_search = RetrievalQA.from_chain_type(
            llm=self.chatglm, chain_type="stuff", retriever=docsearch.as_retriever())
    def init_state_of_history(self):
        # Same pattern for the conversation-history document, with smaller chunks.
        doc_path = str(Path.cwd() / "content/state_of_the_history.txt")
        loader = TextLoader(doc_path)
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)
        docsearch = Chroma.from_documents(texts, self.embeddings, collection_name="state-of-history")
        self.state_of_history = RetrievalQA.from_chain_type(
            llm=self.chatglm, chain_type="stuff", retriever=docsearch.as_retriever())
    def agents_answer(self):
        template = """This is a conversation between a human and a bot:

{chat_history}

Write a summary of the conversation for {input}:
"""
        prompt = PromptTemplate(
            input_variables=["input", "chat_history"],
            template=template
        )
        memory = ConversationBufferMemory(memory_key="chat_history")
        readonlymemory = ReadOnlySharedMemory(memory=memory)
        summary_chain = LLMChain(
            llm=self.chatglm,
            prompt=prompt,
            verbose=True,
            memory=readonlymemory,  # read-only view so the tool cannot modify the shared memory
        )
        return summary_chain, memory
    def create_agent_chain(self):
        tools = [
            Tool(
                name="State of Search QA System",
                func=self.state_of_search.run,
                description="useful for when you need to answer questions about search. The input should be a complete question."
            ),
            Tool(
                name="state-of-history-qa",
                func=self.state_of_history.run,
                description="past conversations with Lulu - useful for answering questions about what has happened between us."
            ),
            Tool(
                name="Summary",
                func=self.summary_chain.run,
                description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
            )
        ]
        prefix = """You are to act as a listener and answer the human's questions as best you can. You have access to the following tools:"""
        suffix = """Begin!

{chat_history}
Question: {input}
{agent_scratchpad}"""
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            input_variables=["input", "chat_history", "agent_scratchpad"]
        )
        # The agent's planning LLM is OpenAI; ChatGLM powers the tools themselves.
        llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        agent_chain = AgentExecutor.from_agent_and_tools(
            agent=agent, tools=tools, verbose=True, memory=self.memory)
        return agent_chain
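
A minimal usage sketch. It assumes the two files under content/ exist and that OPENAI_API_KEY is set in the environment for the planner LLM; the constructor arguments and the question are illustrative only:

if __name__ == "__main__":
    bot = ChatglmWithSharedMemoryOpenaiLLM({
        "embedding_model": "text2vec",
        "llm_model": "chatglm-6b",
        "use_cuda": True,
    })
    # The ZeroShotAgent picks one tool per question; each exchange is also
    # appended to the shared ConversationBufferMemory.
    print(bot.agent_chain.run(input="What did we talk about last time?"))
    # The Summary tool reads that same memory through a read-only view.
    print(bot.summary_chain.run("a colleague"))

Note that only the planning step goes through OpenAI; all three tools answer with the local ChatGLM model.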