llama_llm.py

from langchain.llms.base import LLM
import torch
import transformers
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.extensions.callback import (Iteratorize, Stream, FixedLengthQueue)
import models.shared as shared


def _streaming_response_template() -> Dict[str, Any]:
    """
    :return: the response structure
    """
    return {
        "text": ""
    }


def _update_response(response: Dict[str, Any], stream_response: str) -> None:
    """Update the response with the latest streamed chunk."""
    response["text"] += stream_response
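
# Illustrative use of the two helpers above: the streaming loop in
# LLamaLLM._call() builds one response dict and appends each decoded chunk.
#     resp = _streaming_response_template()   # {"text": ""}
#     _update_response(resp, "Hello, ")
#     _update_response(resp, "world")
#     resp["text"]                             # "Hello, world"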


class LLamaLLM(LLM):
    llm: LoaderCheckPoint = None
    history = []
    history_len: int = 10

    # Default parameters passed to model.generate().
    generate_params: object = {
        'max_new_tokens': 200,
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'encoder_repetition_penalty': 1,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'eos_token_id': [2],
        'stopping_criteria': []
    }

    # Runtime state: sampling settings plus tokenization/loader options.
    state: object = {
        'max_new_tokens': 200,
        'seed': 1,
        'temperature': 0,
        'top_p': 0.1,
        'top_k': 40,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'encoder_repetition_penalty': 1,
        'no_repeat_ngram_size': 0,
        'min_length': 0,
        'do_sample': True,
        'penalty_alpha': 0,
        'num_beams': 1,
        'length_penalty': 1,
        'early_stopping': False,
        'add_bos_token': True,
        'ban_eos_token': False,
        'truncation_length': 2048,
        'custom_stopping_strings': '',
        'cpu_memory': 0,
        'auto_devices': False,
        'disk': False,
        'cpu': False,
        'bf16': False,
        'load_in_8bit': False,
        'wbits': 'None',
        'groupsize': 'None',
        'model_type': 'None',
        'pre_layer': 0,
        'gpu_memory_0': 0
    }
    def __init__(self, llm: LoaderCheckPoint = None):
        super().__init__()
        self.llm = llm

    @property
    def _llm_type(self) -> str:
        return "LLamaLLM"
    def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
        input_ids = self.llm.tokenizer.encode(str(prompt), return_tensors='pt',
                                              add_special_tokens=add_special_tokens)

        # This is a hack for making replies more creative.
        if not add_bos_token and input_ids[0][0] == self.llm.tokenizer.bos_token_id:
            input_ids = input_ids[:, 1:]

        # Llama adds this extra token when the first character is '\n', and this
        # compromises the stopping criteria, so we just remove it.
        if type(self.llm.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
            input_ids = input_ids[:, 1:]

        # Handle truncation: keep only the most recent truncation_length tokens.
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

        return input_ids.cuda()
    def decode(self, output_ids):
        reply = self.llm.tokenizer.decode(output_ids, skip_special_tokens=True)
        reply = reply.replace(r'<|endoftext|>', '')
        return reply
    def get_max_prompt_length(self):
        # Tokens left for the prompt after reserving room for generation,
        # e.g. 2048 - 200 = 1848 with the defaults above.
        max_length = self.state['truncation_length'] - self.state['max_new_tokens']
        return max_length

    def generate_with_callback(self, callback=None, **kwargs):
        kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        self.llm.clear_torch_cache()
        with torch.no_grad():
            self.llm.model.generate(**kwargs)

    def generate_with_streaming(self, callback=None, **kwargs):
        # Wrap the callback-based generate_with_callback into an iterator that
        # yields partial outputs as they are produced.
        return Iteratorize(self.generate_with_callback, kwargs, callback)
    # Convert the conversation history list into plain text.
    def history_to_text(self):
        formatted_history = ''
        history = self.history[-self.history_len:] if self.history_len > 0 else []
        for entry in history:
            role, content = entry
            formatted_history += f"### {role}: {content}\n"
        return formatted_history
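    # Illustrative output (each history entry is assumed to be a two-element
    # list, unpacked as role and content):
    #     self.history = [["Human", "Hi"], ["Assistant", "Hello!"]]
    #     self.history_to_text()
    #     -> "### Human: Hi\n### Assistant: Hello!\n"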
    def generate_softprompt_history_tensors(self, input_ids):
        """
        Soft prompt built from the conversation history.

        history_to_text() converts the self.history list into the required text
        format; the formatted history is then encoded into token ids with
        self.encode, and finally the history token ids are concatenated with
        the token ids of the current input.
        :return: (inputs_embeds, filler_input_ids)
        """
        # Format and encode the conversation history.
        formatted_history = self.history_to_text()
        history_input_ids = self.encode(formatted_history, add_bos_token=self.state['add_bos_token'],
                                        truncation_length=self.get_max_prompt_length())

        # Concatenate the history token ids with the current input token ids.
        inputs_embeds = torch.cat((history_input_ids, input_ids), dim=1)
        filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(self.llm.model.device)
        return inputs_embeds, filler_input_ids
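    # Shape sketch: if the encoded history is (1, H) and input_ids is (1, P),
    # then inputs_embeds is (1, H + P) and filler_input_ids is a zero tensor of
    # the same shape. Despite the names (kept from the original code), both
    # tensors hold token ids, not embeddings.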
    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None) -> str:
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'],
                                truncation_length=self.get_max_prompt_length())
        # self.history[-self.history_len:] if self.history_len > 0 else []
        output = input_ids[0]
        inputs_embeds, filler_input_ids = self.generate_softprompt_history_tensors(input_ids)

        # self.generate_params.update({'inputs_embeds': inputs_embeds})
        self.generate_params.update({'inputs': inputs_embeds})

        shared.stop_everything = False
        stopped = False
        response_template = _streaming_response_template()

        with self.generate_with_streaming(**self.generate_params) as generator:
            last_reply_index = 0
            # Create a FixedLengthQueue with the desired stop sequence and a maximum length.
            queue = FixedLengthQueue(stop)
            for output in generator:
                new_tokens = len(output) - len(input_ids[0])
                reply = self.decode(output[-new_tokens:])
                new_reply = len(reply) - last_reply_index
                output_reply = reply[-new_reply:]

                if last_reply_index > 0 or new_tokens == self.generate_params['max_new_tokens'] - 1 or stopped:
                    if stop:
                        queue.add(output_reply)
                        pos = queue.contains_stop_sequence()
                        if pos != -1:
                            shared.stop_everything = True
                            stopped = True

                _update_response(response_template, output_reply)
                last_reply_index = len(reply)
                if stopped:
                    break

        response = response_template['text']
        self.history = self.history + [[None, response]]
        return response
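

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a
    # checkpoint has already been loaded into shared.loaderCheckPoint elsewhere
    # (e.g. by the project's model loader); that attribute name and the loading
    # step are assumptions, not defined in this file.
    llm = LLamaLLM(llm=shared.loaderCheckPoint)
    llm.history_len = 3
    answer = llm._call("### Human: Hello\n### Assistant: ", stop=["\n###"])
    print(answer)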