chatglm_llm.py
from typing import List, Optional

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoModel, AutoTokenizer


class ChatGLM(LLM):
    """Wrapper around the ChatGLM model to fit the LangChain framework.

    May not be an optimal implementation.
    """

    max_token: int = 10000
    temperature: float = 0.1
    top_p: float = 0.9
    history: List = []

    # The tokenizer and model are loaded once, at class-definition time,
    # with fp16 weights on the default CUDA device.
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b",
        trust_remote_code=True,
    )
    model = (
        AutoModel.from_pretrained(
            "THUDM/chatglm-6b",
            trust_remote_code=True,
        )
        .half()
        .cuda()
    )

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # ChatGLM's chat() returns the reply together with the updated
        # conversation history, which is carried into the next call.
        response, updated_history = self.model.chat(
            self.tokenizer,
            prompt,
            history=self.history,
            max_length=self.max_token,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        print("history: ", self.history)
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = updated_history
        return response
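
# Minimal usage sketch. Assumptions: a CUDA-capable GPU with enough memory
# for chatglm-6b in fp16, and network access so the weights can be fetched
# from Hugging Face on first use; the prompts are illustrative only.
if __name__ == "__main__":
    llm = ChatGLM()
    # The LLM base class makes the wrapper callable; stop sequences are
    # optional and, when given, are enforced via enforce_stop_tokens.
    print(llm("Briefly introduce the ChatGLM-6B model."))
    print(llm("What did I just ask you?"))  # history carries over between calls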