- """Wrapper around FastChat APIs."""
- from __future__ import annotations
- import logging
- import sys
- import warnings
- from typing import (
- AbstractSet,
- Any,
- Callable,
- Collection,
- Dict,
- Generator,
- List,
- Literal,
- Mapping,
- Optional,
- Set,
- Tuple,
- Union,
- )
- from pydantic import Extra, Field, root_validator
- from tenacity import (
- before_sleep_log,
- retry,
- retry_if_exception_type,
- stop_after_attempt,
- wait_exponential,
- )
- from langchain.llms.base import BaseLLM
- from langchain.schema import Generation, LLMResult
- from langchain.utils import get_from_dict_or_env
- import requests
- import json
- logger = logging.getLogger(__name__)
- FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"
- def _streaming_response_template() -> Dict[str, Any]:
- """
- :return: 响应结构
- """
- return {
- "text": "",
- "error_code": 0,
- }
- def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
- """Update response from the stream response."""
- response["text"] += stream_response["text"]
- response["error_code"] += stream_response["error_code"]
- class BaseFastChat(BaseLLM):
- """Wrapper around FastChat large language models."""
- model_name: str = "text-davinci-003"
- """Model name to use."""
- temperature: float = 0.7
- """What sampling temperature to use."""
- max_new_tokens: int = 200
- """Maximum number of new tokens to generate."""
- stop: int = 20
- batch_size: int = 20
- """Batch size to use when passing multiple prompts in to generate."""
- streaming: bool = False
- """Whether to stream the results or not."""
- n: int = 1
- """How many completions to generate for each prompt."""
- model_kwargs: Dict[str, Any] = Field(default_factory=dict)
- """Holds any model parameters valid for the worker call not explicitly specified."""
- allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
- """Set of special tokens that are allowed."""
- disallowed_special: Union[Literal["all"], Collection[str]] = "all"
- """Set of special tokens that are not allowed."""
- class Config:
- """Configuration for this pydantic object."""
- extra = Extra.ignore
- @root_validator(pre=True)
- def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
- """Build extra kwargs from additional params that were passed in."""
- all_required_field_names = {field.alias for field in cls.__fields__.values()}
- extra = values.get("model_kwargs", {})
- for field_name in list(values):
- if field_name not in all_required_field_names:
- if field_name in extra:
- raise ValueError(f"Found {field_name} supplied twice.")
- logger.warning(
- f"""WARNING! {field_name} is not a default parameter.
- {field_name} was transferred to model_kwargs.
- Please confirm that {field_name} is what you intended."""
- )
- extra[field_name] = values.pop(field_name)
- values["model_kwargs"] = extra
- return values
- @property
- def _default_params(self) -> Dict[str, Any]:
- """Get the default parameters for calling FastChat API."""
- normal_params = {
- "model": self.model_name,
- "prompt": '',
- "max_new_tokens": self.max_new_tokens,
- "temperature": self.temperature,
- }
- return {**normal_params}
- def _generate(
- self, prompts: List[str], stop: Optional[List[str]] = None
- ) -> LLMResult:
- """Call out to FastChat's endpoint with k unique prompts.
- Args:
- prompts: The prompts to pass into the model.
- stop: Optional list of stop words to use when generating.
- Returns:
- The full LLM output.
- Example:
- .. code-block:: python
- response = fastchat.generate(["Tell me a joke."])
- """
- # TODO: write a unit test for this
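- # Prompts are split into batches of `batch_size`; for each batch the first
- # prompt is posted to the FastChat worker_generate_stream endpoint, which
- # returns JSON chunks delimited by null bytes (b"\0") that are accumulated
- # into a single response per batch.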
- params = self._invocation_params
- sub_prompts = self.get_sub_prompts(params, prompts)
- choices = []
- token_usage: Dict[str, int] = {}
- headers = {"User-Agent": "fastchat Client"}
- for _prompts in sub_prompts:
- params["prompt"] = _prompts[0]
- if stop is not None:
- if "stop" in params:
- raise ValueError("`stop` found in both the input and default params.")
- params["stop"] = stop
- if self.streaming:
- if len(_prompts) > 1:
- raise ValueError("Cannot stream results with multiple prompts.")
- response_template = _streaming_response_template()
- response = requests.post(
- FAST_CHAT_API,
- headers=headers,
- json=params,
- stream=True,
- )
- for stream_resp in response.iter_lines(
- chunk_size=8192, decode_unicode=False, delimiter=b"\0"
- ):
- if stream_resp:
- data = json.loads(stream_resp.decode("utf-8"))
- skip_echo_len = len(_prompts[0])
- output = data["text"][skip_echo_len:].strip()
- data["text"] = output
- self.callback_manager.on_llm_new_token(
- output,
- verbose=self.verbose,
- logprobs=data["error_code"],
- )
- _update_response(response_template, data)
- choices.append(response_template)
- else:
- response_template = _streaming_response_template()
- response = requests.post(
- FAST_CHAT_API,
- headers=headers,
- json=params,
- stream=True,
- )
- for stream_resp in response.iter_lines(
- chunk_size=8192, decode_unicode=False, delimiter=b"\0"
- ):
- if stream_resp:
- data = json.loads(stream_resp.decode("utf-8"))
- skip_echo_len = len(_prompts[0])
- output = data["text"][skip_echo_len:].strip()
- data["text"] = output
- _update_response(response_template, data)
- choices.append(response_template)
- return self.create_llm_result(choices, prompts, token_usage)
- async def _agenerate(
- self, prompts: List[str], stop: Optional[List[str]] = None
- ) -> LLMResult:
- """Call out to FastChat's endpoint async with k unique prompts."""
- params = self._invocation_params
- sub_prompts = self.get_sub_prompts(params, prompts)
- choices = []
- token_usage: Dict[str, int] = {}
- headers = {"User-Agent": "fastchat Client"}
- for _prompts in sub_prompts:
- params["prompt"] = _prompts[0]
- if stop is not None:
- if "stop" in params:
- raise ValueError("`stop` found in both the input and default params.")
- params["stop"] = stop
- if self.streaming:
- if len(_prompts) > 1:
- raise ValueError("Cannot stream results with multiple prompts.")
- response_template = _streaming_response_template()
- response = requests.post(
- FAST_CHAT_API,
- headers=headers,
- json=params,
- stream=True,
- )
- for stream_resp in response.iter_lines(
- chunk_size=8192, decode_unicode=False, delimiter=b"\0"
- ):
- if stream_resp:
- data = json.loads(stream_resp.decode("utf-8"))
- skip_echo_len = len(_prompts[0])
- output = data["text"][skip_echo_len:].strip()
- data["text"] = output
- self.callback_manager.on_llm_new_token(
- output,
- verbose=self.verbose,
- logprobs=data["error_code"],
- )
- _update_response(response_template, data)
- choices.append(response_template)
- else:
- response_template = _streaming_response_template()
- response = requests.post(
- FAST_CHAT_API,
- headers=headers,
- json=params,
- stream=True,
- )
- for stream_resp in response.iter_lines(
- chunk_size=8192, decode_unicode=False, delimiter=b"\0"
- ):
- if stream_resp:
- data = json.loads(stream_resp.decode("utf-8"))
- skip_echo_len = len(_prompts[0])
- output = data["text"][skip_echo_len:].strip()
- data["text"] = output
- _update_response(response_template, data)
- choices.append(response_template)
- return self.create_llm_result(choices, prompts, token_usage)
- def get_sub_prompts(
- self,
- params: Dict[str, Any],
- prompts: List[str],
- ) -> List[List[str]]:
- """Get the sub prompts for llm call."""
- if params["max_new_tokens"] == -1:
- if len(prompts) != 1:
- raise ValueError(
- "max_new_tokens set to -1 not supported for multiple inputs."
- )
- params["max_new_tokens"] = self.max_new_tokens_for_prompt(prompts[0])
- # Split the prompts into batches of size `batch_size`.
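- # For example, with batch_size=20 and 45 prompts this yields three batches of
- # 20, 20 and 5 prompts.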
- sub_prompts = [
- prompts[i: i + self.batch_size]
- for i in range(0, len(prompts), self.batch_size)
- ]
- return sub_prompts
- def create_llm_result(
- self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
- ) -> LLMResult:
- """Create the LLMResult from the choices and prompts."""
- generations = []
- for i, _ in enumerate(prompts):
- sub_choices = choices[i * self.n: (i + 1) * self.n]
- generations.append(
- [
- Generation(
- text=choice["text"],
- generation_info=dict(
- finish_reason='over',
- logprobs=choice["text"],
- ),
- )
- for choice in sub_choices
- ]
- )
- llm_output = {"token_usage": token_usage, "model_name": self.model_name}
- return LLMResult(generations=generations, llm_output=llm_output)
- def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
- """Call FastChat with streaming flag and return the resulting generator.
- BETA: this is a beta feature while we figure out the right abstraction.
- Once that happens, this interface could change.
- Args:
- prompt: The prompt to pass into the model.
- stop: Optional list of stop words to use when generating.
- Returns:
- A generator representing the stream of tokens from FastChat.
- Example:
- .. code-block:: python
- generator = fastChat.stream("Tell me a joke.")
- for token in generator:
- yield token
- """
- params = self._invocation_params
- params["prompt"] = prompt
- if stop is not None:
- if "stop" in params:
- raise ValueError("`stop` found in both the input and default params.")
- params["stop"] = stop
- headers = {"User-Agent": "fastchat Client"}
- response = requests.post(
- FAST_CHAT_API,
- headers=headers,
- json=params,
- stream=True,
- )
- for stream_resp in response.iter_lines(
- chunk_size=8192, decode_unicode=False, delimiter=b"\0"
- ):
- if stream_resp:
- data = json.loads(stream_resp.decode("utf-8"))
- skip_echo_len = len(prompt)
- output = data["text"][skip_echo_len:].strip()
- data["text"] = output
- yield data
- @property
- def _invocation_params(self) -> Dict[str, Any]:
- """Get the parameters used to invoke the model."""
- return self._default_params
- @property
- def _identifying_params(self) -> Mapping[str, Any]:
- """Get the identifying parameters."""
- return {**{"model_name": self.model_name}, **self._default_params}
- @property
- def _llm_type(self) -> str:
- """Return type of llm."""
- return "fastChat"
- def get_num_tokens(self, text: str) -> int:
- """Calculate num tokens with tiktoken package."""
- # tiktoken NOT supported for Python < 3.8
- if sys.version_info < (3, 8):
- return super().get_num_tokens(text)
- try:
- import tiktoken
- except ImportError:
- raise ValueError(
- "Could not import tiktoken python package. "
- "This is needed in order to calculate get_num_tokens. "
- "Please install it with `pip install tiktoken`."
- )
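- # Note: tiktoken only maps OpenAI model names to encodings, so this lookup
- # fails for FastChat model names such as "vicuna-13b".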
- enc = tiktoken.encoding_for_model(self.model_name)
- tokenized_text = enc.encode(
- text,
- allowed_special=self.allowed_special,
- disallowed_special=self.disallowed_special,
- )
- # calculate the number of tokens in the encoded text
- return len(tokenized_text)
- def modelname_to_contextsize(self, modelname: str) -> int:
- """Calculate the maximum number of tokens possible to generate for a model.
- Args:
- modelname: The modelname we want to know the context size for.
- Returns:
- The maximum context size
- Example:
- .. code-block:: python
- context_size = fastchat.modelname_to_contextsize("vicuna-13b")
- """
- model_token_mapping = {
- "vicuna-13b": 2049,
- "koala": 2049,
- "dolly-v2": 2049,
- "oasst": 2049,
- "stablelm": 2049,
- }
- context_size = model_token_mapping.get(modelname, None)
- if context_size is None:
- raise ValueError(
- f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
- "Known models are: " + ", ".join(model_token_mapping.keys())
- )
- return context_size
- def max_new_tokens_for_prompt(self, prompt: str) -> int:
- """Calculate the maximum number of tokens possible to generate for a prompt.
- Args:
- prompt: The prompt to pass into the model.
- Returns:
- The maximum number of tokens to generate for a prompt.
- Example:
- .. code-block:: python
- max_new_tokens = fastchat.max_new_tokens_for_prompt("Tell me a joke.")
- """
- num_tokens = self.get_num_tokens(prompt)
- # get max context size for model by name
- max_size = self.modelname_to_contextsize(self.model_name)
- return max_size - num_tokens
- class FastChat(BaseFastChat):
- """Wrapper around OpenAI large language models.
- To use, you should have the ``openai`` python package installed, and the
- environment variable ``OPENAI_API_KEY`` set with your API key.
- Any parameters that are valid to be passed to the openai.create call can be passed
- in, even if not explicitly saved on this class.
- Example:
- .. code-block:: python
- from langchain.llms import OpenAI
- openai = FastChat(model_name="vicuna")
- """
- @property
- def _invocation_params(self) -> Dict[str, Any]:
- return {**{"model": self.model_name}, **super()._invocation_params}