chatglm_llm.py

from langchain.llms.base import LLM
from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel
import torch
from configs.model_config import LLM_DEVICE

DEVICE = LLM_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    # Release cached GPU memory after each generation so long chats do not exhaust VRAM.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


class ChatGLM(LLM):
    """LangChain LLM wrapper around the ChatGLM-6B chat model."""
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    history = []
    tokenizer: object = None
    model: object = None
    history_len: int = 10

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None) -> str:
        # Generate a reply, conditioning on at most the last `history_len` turns of history.
        response, _ = self.model.chat(
            self.tokenizer,
            prompt,
            history=self.history[-self.history_len:] if self.history_len > 0 else [],
            max_length=self.max_token,
            temperature=self.temperature,
        )
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = self.history + [[None, response]]
        return response

    def load_model(self,
                   model_name_or_path: str = "THUDM/chatglm-6b",
                   llm_device=LLM_DEVICE):
        # Load tokenizer and weights; FP16 on CUDA devices, FP32 on CPU or other devices.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
            self.model = (
                AutoModel.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=True)
                .half()
                .cuda()
            )
        else:
            self.model = (
                AutoModel.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=True)
                .float()
                .to(llm_device)
            )
        self.model = self.model.eval()
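
A minimal usage sketch, assuming configs.model_config defines LLM_DEVICE and a ChatGLM-6B checkpoint is reachable locally or from the Hugging Face Hub; the prompt and the history_len value below are illustrative only:

# Hypothetical usage sketch; not part of chatglm_llm.py itself.
from chatglm_llm import ChatGLM

llm = ChatGLM()
llm.load_model(model_name_or_path="THUDM/chatglm-6b")  # FP16 on GPU, FP32 otherwise
llm.history_len = 3  # keep only the last 3 turns as conversational context

# _call() is what LangChain invokes internally when the LLM is used in a chain;
# calling it directly keeps this example independent of the installed LangChain version.
print(llm._call("Briefly introduce yourself."))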