chatglm_llm.py

import json
import torch
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import Dict, List, Optional
from configs.model_config import LLM_DEVICE

DEVICE = LLM_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    # Release cached GPU memory after each generation step.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # transformer.word_embeddings takes 1 slot
    # transformer.final_layernorm and lm_head together take 1 slot
    # transformer.layers take 28 slots
    # 30 slots in total, spread across num_gpus cards
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # bugfix: on Linux, torch.embedding can receive weight and input on different
    # devices, raising a RuntimeError.
    # On Windows, model.device ends up as transformer.word_embeddings.device;
    # on Linux, model.device ends up as lm_head.device.
    # chat/stream_chat move input_ids onto model.device, so if
    # transformer.word_embeddings.device differs from model.device a RuntimeError follows.
    # Therefore transformer.word_embeddings, transformer.final_layernorm and lm_head
    # are all pinned to the first card.
    device_map = {'transformer.word_embeddings': 0,
                  'transformer.final_layernorm': 0,
                  'lm_head': 0}

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.layers.{i}'] = gpu_target
        used += 1

    return device_map
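
# Example: with two cards, auto_configure_device_map(2) pins the embeddings,
# final layer norm, lm_head and transformer.layers.0-12 on GPU 0, and
# transformer.layers.13-27 on GPU 1:
#   {'transformer.word_embeddings': 0, 'transformer.final_layernorm': 0,
#    'lm_head': 0, 'transformer.layers.0': 0, ..., 'transformer.layers.12': 0,
#    'transformer.layers.13': 1, ..., 'transformer.layers.27': 1}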


class ChatGLM(LLM):
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    history = []
    tokenizer: object = None
    model: object = None
    history_len: int = 10

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              stream=True) -> str:
        # Note: because this method contains `yield`, calling it always returns
        # a generator; with stream=True it yields progressively longer responses.
        if stream:
            self.history = self.history + [[None, ""]]
            for response, history in self.model.stream_chat(
                    self.tokenizer,
                    prompt,
                    history=self.history[-self.history_len:] if self.history_len > 0 else [],
                    max_length=self.max_token,
                    temperature=self.temperature,
            ):
                torch_gc()
                self.history[-1][-1] = response
                yield response
        else:
            response, _ = self.model.chat(
                self.tokenizer,
                prompt,
                history=self.history[-self.history_len:] if self.history_len > 0 else [],
                max_length=self.max_token,
                temperature=self.temperature,
            )
            torch_gc()
            if stop is not None:
                response = enforce_stop_tokens(response, stop)
            self.history = self.history + [[None, response]]
            return response

    def chat(self,
             prompt: str) -> str:
        response, _ = self.model.chat(
            self.tokenizer,
            prompt,
            history=self.history[-self.history_len:] if self.history_len > 0 else [],
            max_length=self.max_token,
            temperature=self.temperature,
        )
        torch_gc()
        self.history = self.history + [[None, response]]
        return response

    def load_model(self,
                   model_name_or_path: str = "THUDM/chatglm-6b",
                   llm_device=LLM_DEVICE,
                   use_ptuning_v2=False,
                   device_map: Optional[Dict[str, int]] = None,
                   **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        if use_ptuning_v2:
            try:
                with open('ptuning-v2/config.json', 'r') as prefix_encoder_file:
                    prefix_encoder_config = json.loads(prefix_encoder_file.read())
                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception as e:
                print(e)
                print("Failed to load the PrefixEncoder config.json")

        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
            # Choose single-card or multi-card deployment based on the number
            # of GPUs visible on the current machine.
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                self.model = (
                    AutoModel.from_pretrained(
                        model_name_or_path,
                        config=model_config,
                        trust_remote_code=True,
                        **kwargs)
                    .half()
                    .cuda()
                )
            else:
                from accelerate import dispatch_model

                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).half()
                # A custom device_map may be passed in to control how layers
                # are placed on each card.
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus)

                self.model = dispatch_model(model, device_map=device_map)
        else:
            self.model = (
                AutoModel.from_pretrained(
                    model_name_or_path,
                    config=model_config,
                    trust_remote_code=True)
                .float()
                .to(llm_device)
            )

        if use_ptuning_v2:
            try:
                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception as e:
                print(e)
                print("Failed to load the PrefixEncoder model parameters")

        self.model = self.model.eval()
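

# Hypothetical usage sketch (not part of the original module): a minimal way
# this class might be driven, assuming the "THUDM/chatglm-6b" weights are
# reachable and configs.model_config.LLM_DEVICE points at a valid device.
# chat() returns the full answer; _call(stream=True) returns a generator
# yielding progressively longer partial answers.
if __name__ == "__main__":
    llm = ChatGLM()
    llm.load_model(model_name_or_path="THUDM/chatglm-6b")

    answer = llm.chat("Briefly introduce yourself.")
    print(answer)

    for partial in llm._call("What can you do?", stream=True):
        print(partial)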