"""
Conversation prompt templates.

Currently supported models:
- Vicuna
- Koala
- OpenAssistant/oasst-sft-1-pythia-12b
- StabilityAI/stablelm-tuned-alpha-7b
- databricks/dolly-v2-12b
- THUDM/chatglm-6b
- Alpaca/LLaMa
"""
  12. import dataclasses
  13. from enum import auto, Enum
  14. from typing import List, Tuple, Any
class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()        # one separator between all turns (e.g. "###")
    TWO = auto()           # two alternating separators (sep / sep2)
    DOLLY = auto()         # dolly-v2 layout: "Role:\n" headers, "### End" closer
    OASST_PYTHIA = auto()  # role tokens concatenated directly with the message
  21. @dataclasses.dataclass
  22. class Conversation:
  23. """A class that keeps all conversation history."""
  24. system: str
  25. roles: List[str]
  26. messages: List[List[str]]
  27. offset: int
  28. sep_style: SeparatorStyle = SeparatorStyle.SINGLE
  29. sep: str = "###"
  30. sep2: str = None
  31. # Used for gradio server
  32. skip_next: bool = False
  33. conv_id: Any = None
  34. def get_prompt(self):
  35. if self.sep_style == SeparatorStyle.SINGLE:
  36. ret = self.system
  37. for role, message in self.messages:
  38. if message:
  39. ret += self.sep + " " + role + ": " + message
  40. else:
  41. ret += self.sep + " " + role + ":"
  42. return ret
  43. elif self.sep_style == SeparatorStyle.TWO:
  44. seps = [self.sep, self.sep2]
  45. ret = self.system + seps[0]
  46. for i, (role, message) in enumerate(self.messages):
  47. if message:
  48. ret += role + ": " + message + seps[i % 2]
  49. else:
  50. ret += role + ":"
  51. return ret
  52. elif self.sep_style == SeparatorStyle.DOLLY:
  53. seps = [self.sep, self.sep2]
  54. ret = self.system
  55. for i, (role, message) in enumerate(self.messages):
  56. if message:
  57. ret += role + ":\n" + message + seps[i % 2]
  58. if i % 2 == 1:
  59. ret += "\n\n"
  60. else:
  61. ret += role + ":\n"
  62. return ret
  63. elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
  64. ret = self.system
  65. for role, message in self.messages:
  66. if message:
  67. ret += role + message + self.sep
  68. else:
  69. ret += role
  70. return ret
  71. else:
  72. raise ValueError(f"Invalid style: {self.sep_style}")
  73. def append_message(self, role, message):
  74. self.messages.append([role, message])
  75. def to_gradio_chatbot(self):
  76. ret = []
  77. for i, (role, msg) in enumerate(self.messages[self.offset :]):
  78. if i % 2 == 0:
  79. ret.append([msg, None])
  80. else:
  81. ret[-1][-1] = msg
  82. return ret
  83. def copy(self):
  84. return Conversation(
  85. system=self.system,
  86. roles=self.roles,
  87. messages=[[x, y] for x, y in self.messages],
  88. offset=self.offset,
  89. sep_style=self.sep_style,
  90. sep=self.sep,
  91. sep2=self.sep2,
  92. conv_id=self.conv_id,
  93. )
  94. def dict(self):
  95. return {
  96. "system": self.system,
  97. "roles": self.roles,
  98. "messages": self.messages,
  99. "offset": self.offset,
  100. "sep": self.sep,
  101. "sep2": self.sep2,
  102. "conv_id": self.conv_id,
  103. }
# One-shot template: "Human"/"Assistant" roles with "###" separators and a
# seeded example QA turn. offset=2 keeps the seeded pair out of the gradio
# display while it remains in the prompt.
conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
# Vicuna v1.1 template: TWO style — " " after user turns, "</s>" after
# assistant turns.
conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
# Koala v1 template: same TWO-style separators as Vicuna but with the
# "BEGINNING OF CONVERSATION:" preamble and "USER"/"GPT" roles.
conv_koala_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
# Dolly-v2 template: instruction/response headers, "### End" terminating each
# assistant turn (DOLLY style adds a blank line after odd-indexed turns).
conv_dolly = Conversation(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
    roles=("### Instruction", "### Response"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.DOLLY,
    sep="\n\n",
    sep2="### End",
)
# OpenAssistant pythia template: special role tokens abut the message, each
# turn terminated by "<|endoftext|>"; no system prompt.
conv_oasst = Conversation(
    system="",
    roles=("<|prompter|>", "<|assistant|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="<|endoftext|>",
)
# StableLM tuned-alpha template: reuses OASST_PYTHIA concatenation with
# StableLM's <|SYSTEM|>/<|USER|>/<|ASSISTANT|> tokens and an empty separator.
conv_stablelm = Conversation(
    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
    roles=("<|USER|>", "<|ASSISTANT|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="",
)
  188. conv_templates = {
  189. "conv_one_shot": conv_one_shot,
  190. "vicuna_v1.1": conv_vicuna_v1_1,
  191. "koala_v1": conv_koala_v1,
  192. "dolly": conv_dolly,
  193. "oasst": conv_oasst,
  194. }
  195. def get_default_conv_template(model_name):
  196. model_name = model_name.lower()
  197. if "vicuna" in model_name or "output" in model_name:
  198. return conv_vicuna_v1_1
  199. elif "koala" in model_name:
  200. return conv_koala_v1
  201. elif "dolly-v2" in model_name:
  202. return conv_dolly
  203. elif "oasst" in model_name and "pythia" in model_name:
  204. return conv_oasst
  205. elif "stablelm" in model_name:
  206. return conv_stablelm
  207. return conv_one_shot
  208. def compute_skip_echo_len(model_name, conv, prompt):
  209. model_name = model_name.lower()
  210. if "chatglm" in model_name:
  211. skip_echo_len = len(conv.messages[-2][1]) + 1
  212. elif "dolly-v2" in model_name:
  213. special_toks = ["### Instruction:", "### Response:", "### End"]
  214. skip_echo_len = len(prompt)
  215. for tok in special_toks:
  216. skip_echo_len -= prompt.count(tok) * len(tok)
  217. elif "oasst" in model_name and "pythia" in model_name:
  218. special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
  219. skip_echo_len = len(prompt)
  220. for tok in special_toks:
  221. skip_echo_len -= prompt.count(tok) * len(tok)
  222. elif "stablelm" in model_name:
  223. special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
  224. skip_echo_len = len(prompt)
  225. for tok in special_toks:
  226. skip_echo_len -= prompt.count(tok) * len(tok)
  227. else:
  228. skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
  229. return skip_echo_len
  230. if __name__ == "__main__":
  231. print(default_conversation.get_prompt())