@@ -72,29 +72,29 @@ class ChatGLM(LLM):
         response, _ = self.model.chat(
             self.tokenizer,
             prompt,
-            history=self.history[-self.history_len:] if self.history_len>0 else [],
+            history=self.history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
         )
         torch_gc()
         if stop is not None:
             response = enforce_stop_tokens(response, stop)
-        self.history = self.history+[[None, response]]
+        self.history = self.history + [[None, response]]
         return response
 
     def chat(self,
-             prompt: str) -> str:
+             prompt: str) -> str:
         response, _ = self.model.chat(
             self.tokenizer,
             prompt,
-            history=[],#self.history[-self.history_len:] if self.history_len>0 else
+            history=self.history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
         )
         torch_gc()
-        self.history = self.history+[[None, response]]
+        self.history = self.history + [[None, response]]
         return response
-
+
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
@@ -126,7 +126,7 @@ class ChatGLM(LLM):
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     config=model_config,
-                    trust_remote_code=True,
+                    trust_remote_code=True,
                     **kwargs)
                 .half()
                 .cuda()
@@ -159,7 +159,8 @@ class ChatGLM(LLM):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
-            except Exception:
+            except Exception as e:
+                print(e)
                 print("加载PrefixEncoder模型参数失败")
 
         self.model = self.model.eval()