# chatglm_llm.py

import json
from typing import Dict, List, Optional

import torch
from langchain.llms.base import LLM
from transformers import AutoConfig, AutoModel, AutoTokenizer

from configs.model_config import *
from utils import torch_gc

DEVICE_ = LLM_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
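# Illustrative note (not part of the original file): with LLM_DEVICE == "cuda" and a
# visible GPU, DEVICE resolves to "cuda:0"; without CUDA it stays the bare LLM_DEVICE string.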


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # transformer.word_embeddings occupies 1 layer
    # transformer.final_layernorm and lm_head occupy 1 layer
    # transformer.layers occupies 28 layers
    # In total, 30 layers are distributed across num_gpus GPUs
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # bugfix: on Linux, torch.embedding was called with weight and input on different
    # devices, raising a RuntimeError.
    # On Windows, model.device is set to transformer.word_embeddings.device;
    # on Linux, model.device is set to lm_head.device.
    # When chat or stream_chat is called, input_ids is placed on model.device;
    # if transformer.word_embeddings.device and model.device differ, a RuntimeError is raised.
    # Therefore transformer.word_embeddings, transformer.final_layernorm and lm_head
    # are all placed on the first GPU here.
    device_map = {'transformer.word_embeddings': 0,
                  'transformer.final_layernorm': 0, 'lm_head': 0}

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.layers.{i}'] = gpu_target
        used += 1

    return device_map
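
# Illustrative example (not part of the original file): with two GPUs,
# auto_configure_device_map(2) keeps word_embeddings, final_layernorm, lm_head and
# transformer.layers.0-12 on GPU 0 and places transformer.layers.13-27 on GPU 1:
#
#   device_map = auto_configure_device_map(2)
#   # {'transformer.word_embeddings': 0, 'transformer.final_layernorm': 0, 'lm_head': 0,
#   #  'transformer.layers.0': 0, ..., 'transformer.layers.12': 0,
#   #  'transformer.layers.13': 1, ..., 'transformer.layers.27': 1}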


class ChatGLM(LLM):
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    # history = []
    tokenizer: object = None
    model: object = None
    history_len: int = 10

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"
    def _call(self,
              prompt: str,
              history: List[List[str]] = [],
              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
        if streaming:
            # Streaming mode: yield the partial response and updated history after each step.
            for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                    self.tokenizer,
                    prompt,
                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
                    max_length=self.max_token,
                    temperature=self.temperature,
            )):
                torch_gc()
                if inum == 0:
                    history += [[prompt, stream_resp]]
                else:
                    history[-1] = [prompt, stream_resp]
                yield stream_resp, history
                torch_gc()
        else:
            # Non-streaming mode: yield the full response and updated history once.
            response, _ = self.model.chat(
                self.tokenizer,
                prompt,
                history=history[-self.history_len:] if self.history_len > 0 else [],
                max_length=self.max_token,
                temperature=self.temperature,
            )
            torch_gc()
            history += [[prompt, response]]
            yield response, history
            torch_gc()
    # def chat(self,
    #          prompt: str) -> str:
    #     response, _ = self.model.chat(
    #         self.tokenizer,
    #         prompt,
    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
    #         max_length=self.max_token,
    #         temperature=self.temperature,
    #     )
    #     torch_gc()
    #     self.history = self.history + [[None, response]]
    #     return response
    def load_model(self,
                   model_name_or_path: str = "THUDM/chatglm-6b",
                   llm_device=LLM_DEVICE,
                   use_ptuning_v2=False,
                   use_lora=False,
                   device_map: Optional[Dict[str, int]] = None,
                   **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )

        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        if use_ptuning_v2:
            try:
                with open('ptuning-v2/config.json', 'r') as prefix_encoder_file:
                    prefix_encoder_config = json.loads(prefix_encoder_file.read())
                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception as e:
                print(e)
                print("Failed to load PrefixEncoder config.json")

        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
                                               **kwargs)
        if LLM_LORA_PATH and use_lora:
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)

        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
            # Decide whether to deploy across multiple GPUs based on how many are available.
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                self.model = self.model.half().cuda()
            else:
                from accelerate import dispatch_model

                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
                                                  config=model_config, **kwargs)
                if LLM_LORA_PATH and use_lora:
                    from peft import PeftModel
                    model = PeftModel.from_pretrained(model, LLM_LORA_PATH)
                # A device_map can be passed in to customize how modules are assigned to each GPU.
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus)

                self.model = dispatch_model(model.half(), device_map=device_map)
        else:
            self.model = self.model.float().to(llm_device)

        if use_ptuning_v2:
            try:
                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception as e:
                print(e)
                print("Failed to load PrefixEncoder model parameters")

        self.model = self.model.eval()
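
# Usage sketches (illustrative, not part of the original file); the model path and the
# two-GPU layout below are assumptions, not values this file defines:
#
#   llm = ChatGLM()
#   # Load with P-Tuning v2 prefix weights read from ptuning-v2/:
#   llm.load_model(model_name_or_path="THUDM/chatglm-6b", use_ptuning_v2=True)
#
#   # Or pass an explicit device_map (here built for 2 GPUs) instead of letting
#   # load_model call auto_configure_device_map itself:
#   llm.load_model(model_name_or_path="THUDM/chatglm-6b",
#                  device_map=auto_configure_device_map(2))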

if __name__ == "__main__":
    llm = ChatGLM()
    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
                   llm_device=LLM_DEVICE, )
    last_print_len = 0
    for resp, history in llm._call("你好", streaming=True):
        print(resp[last_print_len:], end="", flush=True)
        last_print_len = len(resp)
    for resp, history in llm._call("你好", streaming=False):
        print(resp)
    pass