
Remove unnecessary formatting code

glide-the 2 years ago
commit ee0ba302f5
1 changed file with 13 additions and 32 deletions
  1. fastchat/api/fastchat_api.py  +13 -32

+13 -32
fastchat/api/fastchat_api.py

@@ -32,17 +32,13 @@ from tenacity import (
 from langchain.llms.base import BaseLLM
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
-from .conversation import (
-    get_default_conv_template,
-    compute_skip_echo_len,
-    SeparatorStyle,
-)
+
 import requests
 import json
 
 
 logger = logging.getLogger(__name__)
-FAST_CHAT_API = "http://localhost:21001/worker_generate_stream"
+FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"
 
 
 def _streaming_response_template() -> Dict[str, Any]:
@@ -68,7 +64,7 @@ class BaseFastChat(BaseLLM):
     """Model name to use."""
     temperature: float = 0.7
     """What sampling temperature to use."""
-    max_new_tokens: int = 256
+    max_new_tokens: int = 200
     stop: int = 20
     batch_size: int = 20
     """Maximum number of retries to make when generating."""
@@ -142,17 +138,12 @@ class BaseFastChat(BaseLLM):
         headers = {"User-Agent": "fastchat Client"}
         for _prompts in sub_prompts:
 
-            conv = get_default_conv_template(self.model_name).copy()
-            conv.append_message(conv.roles[0], _prompts[0])
-            conv.append_message(conv.roles[1], None)
-            prompt = conv.get_prompt()
-            params["prompt"] = prompt
+            params["prompt"] = _prompts[0]
 
-            # use get_default_conv_template stop unit
             if stop is not None:
                 if "stop" in params:
                     raise ValueError("`stop` found in both the input and default params.")
-                params["stop"] = stop[0]
+                params["stop"] = stop
 
             if self.streaming:
                 if len(_prompts) > 1:
@@ -170,7 +161,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         self.callback_manager.on_llm_new_token(
@@ -193,7 +184,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         _update_response(response_template, data)
@@ -214,16 +205,11 @@ class BaseFastChat(BaseLLM):
         headers = {"User-Agent": "fastchat Client"}
         for _prompts in sub_prompts:
 
-            conv = get_default_conv_template(self.model_name).copy()
-            conv.append_message(conv.roles[0], _prompts[0])
-            conv.append_message(conv.roles[1], None)
-            prompt = conv.get_prompt()
-            params["prompt"] = prompt
-            # use get_default_conv_template stop unit
+            params["prompt"] = _prompts[0]
             if stop is not None:
                 if "stop" in params:
                     raise ValueError("`stop` found in both the input and default params.")
-                params["stop"] = stop[0]
+                params["stop"] = stop
 
             if self.streaming:
                 if len(_prompts) > 1:
@@ -241,7 +227,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         self.callback_manager.on_llm_new_token(
@@ -264,7 +250,7 @@ class BaseFastChat(BaseLLM):
                 ):
                     if stream_resp:
                         data = json.loads(stream_resp.decode("utf-8"))
-                        skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                        skip_echo_len = len(_prompts[0])
                         output = data["text"][skip_echo_len:].strip()
                         data["text"] = output
                         _update_response(response_template, data)
@@ -336,16 +322,11 @@ class BaseFastChat(BaseLLM):
                     yield token
         """
         params = self._invocation_params
-        conv = get_default_conv_template(self.model_name).copy()
-        conv.append_message(conv.roles[0], prompt)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
         params["prompt"] = prompt
-        # use get_default_conv_template stop unit
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop[0]
+            params["stop"] = stop
 
         headers = {"User-Agent": "fastchat Client"}
         response = requests.post(
@@ -359,7 +340,7 @@ class BaseFastChat(BaseLLM):
         ):
             if stream_resp:
                 data = json.loads(stream_resp.decode("utf-8"))
-                skip_echo_len = compute_skip_echo_len(self.model_name, conv, prompt)
+                skip_echo_len = len(prompt)
                 output = data["text"][skip_echo_len:].strip()
                 data["text"] = output
                 yield data
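
For context, here is a minimal sketch of how a client would now call the worker endpoint the way the patched BaseFastChat does: the raw prompt is sent without any conversation-template wrapping, stop is forwarded unchanged, and the echoed prompt is stripped with len(prompt). The stream_completion helper is hypothetical, and the exact JSON payload shape, the stream=True usage, and the null-byte delimiter are assumptions inferred from this diff rather than a documented FastChat contract.

import json
import requests

FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"

def stream_completion(prompt, stop=None, max_new_tokens=200, temperature=0.7):
    # Send the raw prompt directly; the conversation-template wrapping was removed.
    params = {
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
    }
    if stop is not None:
        params["stop"] = stop  # forwarded as-is, no longer stop[0]

    headers = {"User-Agent": "fastchat Client"}
    response = requests.post(
        FAST_CHAT_API, headers=headers, json=params, stream=True
    )
    # Each streamed chunk echoes the prompt, so skip len(prompt) characters.
    skip_echo_len = len(prompt)
    for stream_resp in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if stream_resp:
            data = json.loads(stream_resp.decode("utf-8"))
            yield data["text"][skip_echo_len:].strip()

A caller would consume it as: for chunk in stream_completion("Hello"): print(chunk).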