@@ -32,17 +32,13 @@ from tenacity import (
 from langchain.llms.base import BaseLLM
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
-from .conversation import (
-    get_default_conv_template,
-    compute_skip_echo_len,
-    SeparatorStyle,
-)
+
 import requests
 import json
 
 logger = logging.getLogger(__name__)
 
-FAST_CHAT_API = "http://localhost:21001/worker_generate_stream"
+FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"
 
 
 def _streaming_response_template() -> Dict[str, Any]:
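Note: the hunk above repoints FAST_CHAT_API from port 21001 (conventionally the FastChat controller) to port 21002 (conventionally the first model worker) and drops the local conversation-template helpers. A rough, hedged sketch of what calling that worker endpoint directly looks like follows; payload fields beyond "prompt" and "stop", and the b"\0" record delimiter, are assumptions about the FastChat worker API rather than something this patch defines.

    # Illustrative sketch only: field names other than "prompt"/"stop" and the
    # b"\0" delimiter are assumptions, not part of this patch.
    import json
    import requests

    FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"
    headers = {"User-Agent": "fastchat Client"}
    params = {
        "prompt": "Hello, who are you?",
        "temperature": 0.7,     # matches the class default
        "max_new_tokens": 200,  # matches the new default in the next hunk
        "stop": None,
    }

    response = requests.post(FAST_CHAT_API, headers=headers, json=params, stream=True)
    text = ""
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            text = json.loads(chunk.decode("utf-8"))["text"]  # each record holds the full text so far
    print(text[len(params["prompt"]):].strip())  # drop the echoed prompt, as the patched code does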
@@ -68,7 +64,7 @@ class BaseFastChat(BaseLLM):
     """Model name to use."""
     temperature: float = 0.7
     """What sampling temperature to use."""
-    max_new_tokens: int = 256
+    max_new_tokens: int = 200
     stop: int = 20
     batch_size: int = 20
     """Maximum number of retries to make when generating."""
@@ -142,17 +138,12 @@ class BaseFastChat(BaseLLM):
         headers = {"User-Agent": "fastchat Client"}
         for _prompts in sub_prompts:
-            conv = get_default_conv_template(self.model_name).copy()
-            conv.append_message(conv.roles[0], _prompts[0])
-            conv.append_message(conv.roles[1], None)
-            prompt = conv.get_prompt()
-            params["prompt"] = prompt
-            # use get_default_conv_template stop unit
+            params["prompt"] = _prompts[0]
             if stop is not None:
                 if "stop" in params:
                     raise ValueError("`stop` found in both the input and default params.")
-                params["stop"] = stop[0]
+                params["stop"] = stop
 
             if self.streaming:
                 if len(_prompts) > 1:
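With the conversation template gone, this hunk sends the caller's prompt verbatim and forwards the whole stop value instead of just stop[0], so any chat-style formatting now has to happen before the prompt reaches this class. A hedged example of caller-side formatting: the "USER:"/"ASSISTANT:" markers are an assumption about Vicuna-style templates, and fastchat_llm is a hypothetical instance of this class.

    # Hypothetical caller-side prompt formatting; the role markers and the
    # `fastchat_llm` instance are assumptions for illustration only.
    question = "What is the capital of France?"
    prompt = f"USER: {question}\nASSISTANT:"
    answer = fastchat_llm(prompt, stop=["USER:"])  # full stop list is forwarded to the worker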
@@ -170,7 +161,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         self.callback_manager.on_llm_new_token(
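The worker's streamed "text" field still begins with the echoed prompt; this hunk, and the later hunks that repeat the same change, now strip it by plain prompt length instead of calling compute_skip_echo_len. A minimal illustration of that slicing, assuming the worker echoes the prompt verbatim at the start of the text (the sample output is invented):

    # Toy illustration of the echo-skip logic; the streamed text is made up.
    prompt = "Hello how are you?"
    streamed_text = "Hello how are you? I am doing well."  # full text returned by the worker so far
    skip_echo_len = len(prompt)
    output = streamed_text[skip_echo_len:].strip()
    assert output == "I am doing well."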
@@ -193,7 +184,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         _update_response(response_template, data)
@@ -214,16 +205,11 @@ class BaseFastChat(BaseLLM):
         headers = {"User-Agent": "fastchat Client"}
         for _prompts in sub_prompts:
-            conv = get_default_conv_template(self.model_name).copy()
-            conv.append_message(conv.roles[0], _prompts[0])
-            conv.append_message(conv.roles[1], None)
-            prompt = conv.get_prompt()
-            params["prompt"] = prompt
-            # use get_default_conv_template stop unit
+            params["prompt"] = _prompts[0]
             if stop is not None:
                 if "stop" in params:
                     raise ValueError("`stop` found in both the input and default params.")
-                params["stop"] = stop[0]
+                params["stop"] = stop
 
             if self.streaming:
                 if len(_prompts) > 1:
@@ -241,7 +227,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         self.callback_manager.on_llm_new_token(
@@ -264,7 +250,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         _update_response(response_template, data)
@@ -336,16 +322,11 @@ class BaseFastChat(BaseLLM):
                 yield token
         """
         params = self._invocation_params
-        conv = get_default_conv_template(self.model_name).copy()
-        conv.append_message(conv.roles[0], prompt)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
         params["prompt"] = prompt
-        # use get_default_conv_template stop unit
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop[0]
+            params["stop"] = stop
 
         headers = {"User-Agent": "fastchat Client"}
         response = requests.post(
@@ -359,7 +340,7 @@ class BaseFastChat(BaseLLM):
         ):
             if stream_resp:
                 data = json.loads(stream_resp.decode("utf-8"))
-                skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                skip_echo_len = len(prompt)
                 output = data["text"][skip_echo_len:].strip()
                 data["text"] = output
                 yield data
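Note on the last hunk: the surrounding method works on a single `prompt` (see the `params["prompt"] = prompt` context in the previous hunk), so the echo is skipped with `len(prompt)` rather than `len(_prompts[0])`, which is only defined inside the batched loops of the earlier hunks. For completeness, a hedged sketch of consuming the streaming generator patched here; the method name `stream` and the `fastchat_llm` instance are assumptions, while the yielded dict's "text" key comes from the diff itself.

    # Hypothetical usage; `fastchat_llm` and the `stream` method name are assumed.
    latest = ""
    for data in fastchat_llm.stream("Tell me a joke.", stop=["\n\n"]):
        latest = data["text"]  # each yielded record carries the de-echoed text generated so far
    print(latest)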