- """
- Conversation prompt template.
- Now we support
- - Vicuna
- - Koala
- - OpenAssistant/oasst-sft-1-pythia-12b
- - StabilityAI/stablelm-tuned-alpha-7b
- - databricks/dolly-v2-12b
- - THUDM/chatglm-6b
- - Alpaca/LLaMa
- """
import dataclasses
from enum import auto, Enum
from typing import Any, List, Optional


class SeparatorStyle(Enum):
    """Different separator styles."""

    SINGLE = auto()
    TWO = auto()
    DOLLY = auto()
    OASST_PYTHIA = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None

    # Used for the gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")
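
    # Illustration (a sketch, not part of the original template strings): with
    # sep_style=SINGLE and sep="###", a conversation renders as
    #   "<system>### Human: <msg>### Assistant: <msg>### Human: <msg>### Assistant:"
    # A trailing role whose message is None leaves a bare "Assistant:" cue so
    # the model knows it should generate the next turn. The TWO style
    # alternates sep and sep2 (e.g. " " and "</s>"), so only completed
    # assistant turns are closed with the end-of-sequence token.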

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }

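# Minimal usage sketch (illustrative; "A helpful assistant." is a made-up
# system prompt, not one of the shipped templates):
#
#   conv = Conversation(
#       system="A helpful assistant.",
#       roles=("Human", "Assistant"),
#       messages=[],
#       offset=0,
#   )
#   conv.append_message(conv.roles[0], "Hello!")
#   conv.append_message(conv.roles[1], None)  # None = let the model fill this turn
#   conv.get_prompt()
#   # -> 'A helpful assistant.### Human: Hello!### Assistant:'
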
conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
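
# Rendered shape of a vicuna_v1.1 prompt (sketch): with sep=" " and
# sep2="</s>", get_prompt() yields
#   "<system> USER: <question> ASSISTANT: <answer></s>USER: ..."
# so each completed assistant turn ends with the </s> stop token.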

conv_koala_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_dolly = Conversation(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
    roles=("### Instruction", "### Response"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.DOLLY,
    sep="\n\n",
    sep2="### End",
)

conv_oasst = Conversation(
    system="",
    roles=("<|prompter|>", "<|assistant|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="<|endoftext|>",
)

conv_stablelm = Conversation(
    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
    roles=("<|USER|>", "<|ASSISTANT|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="",
)

# Default template; also the fallback that get_default_conv_template()
# returns for unrecognized model names.
default_conversation = conv_one_shot

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1.1": conv_vicuna_v1_1,
    "koala_v1": conv_koala_v1,
    "dolly": conv_dolly,
    "oasst": conv_oasst,
}
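
# Typical lookup (illustrative): fetch a template by name and copy it before
# mutating, since the registry entries are shared module-level objects.
#
#   conv = conv_templates["vicuna_v1.1"].copy()
#   conv.append_message(conv.roles[0], "Hi!")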

def get_default_conv_template(model_name):
    model_name = model_name.lower()
    if "vicuna" in model_name or "output" in model_name:
        return conv_vicuna_v1_1
    elif "koala" in model_name:
        return conv_koala_v1
    elif "dolly-v2" in model_name:
        return conv_dolly
    elif "oasst" in model_name and "pythia" in model_name:
        return conv_oasst
    elif "stablelm" in model_name:
        return conv_stablelm
    return conv_one_shot
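
# Matching is substring-based (illustrative): "vicuna-7b-v1.1" picks
# conv_vicuna_v1_1, "oasst-sft-1-pythia-12b" picks conv_oasst, and any
# unrecognized name falls back to conv_one_shot.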

def compute_skip_echo_len(model_name, conv, prompt):
    model_name = model_name.lower()
    if "chatglm" in model_name:
        skip_echo_len = len(conv.messages[-2][1]) + 1
    elif "dolly-v2" in model_name:
        special_toks = ["### Instruction:", "### Response:", "### End"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "oasst" in model_name and "pythia" in model_name:
        special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "stablelm" in model_name:
        special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    else:
        skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
    return skip_echo_len
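
# Why this accounting works (a sketch, assuming the common serving setup):
# generation servers that echo the prompt return prompt + completion as one
# string, so callers slice off the first skip_echo_len characters. Special
# tokens such as "<|assistant|>" occupy characters in the prompt string but
# are stripped from the detokenized output, hence prompt.count(tok) * len(tok)
# is subtracted for each of them.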

if __name__ == "__main__":
    print(default_conversation.get_prompt())
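
    # A slightly longer smoke test (illustrative): render one turn in each
    # registered template.
    for name, template in conv_templates.items():
        conv = template.copy()
        conv.append_message(conv.roles[0], "Hello!")
        conv.append_message(conv.roles[1], None)
        print(f"--- {name} ---")
        print(conv.get_prompt())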