fastchat_api.py

  1. """Wrapper around FastChat APIs."""
  2. from __future__ import annotations
  3. import logging
  4. import sys
  5. import warnings
  6. from typing import (
  7. AbstractSet,
  8. Any,
  9. Callable,
  10. Collection,
  11. Dict,
  12. Generator,
  13. List,
  14. Literal,
  15. Mapping,
  16. Optional,
  17. Set,
  18. Tuple,
  19. Union,
  20. )
  21. from pydantic import Extra, Field, root_validator
  22. from tenacity import (
  23. before_sleep_log,
  24. retry,
  25. retry_if_exception_type,
  26. stop_after_attempt,
  27. wait_exponential,
  28. )
  29. from langchain.llms.base import BaseLLM
  30. from langchain.schema import Generation, LLMResult
  31. from langchain.utils import get_from_dict_or_env
  32. import requests
  33. import json
  34. logger = logging.getLogger(__name__)
  35. FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"
def _streaming_response_template() -> Dict[str, Any]:
    """Return an empty response template that streamed chunks are merged into."""
    return {
        "text": "",
        "error_code": 0,
    }


def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
    """Update the aggregated response from a single stream chunk."""
    response["text"] += stream_response["text"]
    response["error_code"] += stream_response["error_code"]
class BaseFastChat(BaseLLM):
    """Wrapper around FastChat large language models."""

    model_name: str = "text-davinci-003"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_new_tokens: int = 200
    """Maximum number of new tokens to generate per completion."""
    stop: int = 20
    batch_size: int = 20
    """Batch size to use when passing multiple prompts to the model."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """How many completions to generate for each prompt."""
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
    """Set of special tokens that are allowed."""
    disallowed_special: Union[Literal["all"], Collection[str]] = "all"
    """Set of special tokens that are not allowed."""
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the FastChat API."""
        normal_params = {
            "model": self.model_name,
            "prompt": "",
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }
        return {**normal_params}
    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = fastchat.generate(["Tell me a joke."])
        """
        # TODO: write a unit test for this
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}
        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:
            params["prompt"] = _prompts[0]
            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        # The worker echoes the prompt; strip it from the output.
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)
                choices.append(response_template)
        return self.create_llm_result(choices, prompts, token_usage)
    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint async with k unique prompts."""
        # NOTE: this mirrors the synchronous path and still issues blocking
        # ``requests`` calls.
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}
        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:
            params["prompt"] = _prompts[0]
            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)
                choices.append(response_template)
        return self.create_llm_result(choices, prompts, token_usage)
    def get_sub_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if params["max_new_tokens"] == -1:
            if len(prompts) != 1:
                raise ValueError(
                    "max_new_tokens set to -1 not supported for multiple inputs."
                )
            params["max_new_tokens"] = self.max_new_tokens_for_prompt(prompts[0])
        # Split prompts into batches of size `batch_size`.
        sub_prompts = [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts
    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
            sub_choices = choices[i * self.n : (i + 1) * self.n]
            generations.append(
                [
                    Generation(
                        text=choice["text"],
                        generation_info=dict(
                            finish_reason="over",
                            logprobs=choice["text"],
                        ),
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        return LLMResult(generations=generations, llm_output=llm_output)
    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
        """Call FastChat with the streaming flag and return the resulting generator.

        BETA: this is a beta feature while we figure out the right abstraction.
        Once that happens, this interface could change.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from FastChat.

        Example:
            .. code-block:: python

                generator = fastchat.stream("Tell me a joke.")
                for token in generator:
                    yield token
        """
        params = self._invocation_params
        params["prompt"] = prompt
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        headers = {"User-Agent": "fastchat Client"}
        response = requests.post(
            FAST_CHAT_API,
            headers=headers,
            json=params,
            stream=True,
        )
        for stream_resp in response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        ):
            if stream_resp:
                data = json.loads(stream_resp.decode("utf-8"))
                # Strip the echoed prompt from the streamed text.
                skip_echo_len = len(prompt)
                output = data["text"][skip_echo_len:].strip()
                data["text"] = output
                yield data
    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return self._default_params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fastChat"
    def get_num_tokens(self, text: str) -> int:
        """Calculate num tokens with the tiktoken package."""
        # tiktoken is NOT supported for Python < 3.8
        if sys.version_info[1] < 8:
            return super().get_num_tokens(text)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
            )
        enc = tiktoken.encoding_for_model(self.model_name)
        tokenized_text = enc.encode(
            text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )
        # Calculate the number of tokens in the encoded text.
        return len(tokenized_text)
    def modelname_to_contextsize(self, modelname: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a model.

        Args:
            modelname: The modelname we want to know the context size for.

        Returns:
            The maximum context size.

        Example:
            .. code-block:: python

                max_new_tokens = fastchat.modelname_to_contextsize("vicuna-13b")
        """
        model_token_mapping = {
            "vicuna-13b": 2049,
            "koala": 2049,
            "dolly-v2": 2049,
            "oasst": 2049,
            "stablelm": 2049,
        }
        context_size = model_token_mapping.get(modelname, None)
        if context_size is None:
            raise ValueError(
                f"Unknown model: {modelname}. Please provide a valid FastChat model name."
                " Known models are: " + ", ".join(model_token_mapping.keys())
            )
        return context_size
    def max_new_tokens_for_prompt(self, prompt: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The maximum number of tokens to generate for a prompt.

        Example:
            .. code-block:: python

                max_new_tokens = fastchat.max_new_tokens_for_prompt("Tell me a joke.")
        """
        num_tokens = self.get_num_tokens(prompt)
        # Get the max context size for the model by name.
        max_size = self.modelname_to_contextsize(self.model_name)
        return max_size - num_tokens
class FastChat(BaseFastChat):
    """Wrapper around FastChat large language models.

    To use, you should have a FastChat model worker running and reachable at the
    address in ``FAST_CHAT_API``. Any parameters that are valid to be passed to the
    worker's generate call can be passed in, even if not explicitly saved on this
    class.

    Example:
        .. code-block:: python

            fastchat = FastChat(model_name="vicuna")
    """

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**{"model": self.model_name}, **super()._invocation_params}
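

# A minimal usage sketch, assuming a FastChat model worker is serving
# ``worker_generate_stream`` at FAST_CHAT_API (http://localhost:21002) with a model
# such as "vicuna-13b" loaded. The model name and endpoint are assumptions; adjust
# them to match your deployment before running.
if __name__ == "__main__":
    llm = FastChat(model_name="vicuna-13b", temperature=0.7, max_new_tokens=128)

    # Blocking call: send a single prompt and print the first generation.
    result = llm.generate(["Tell me a joke."])
    print(result.generations[0][0].text)

    # Streaming call: consume the null-delimited JSON chunks one at a time.
    for chunk in llm.stream("Tell me a joke."):
        print(chunk["text"], end="", flush=True)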