
update chatglm_llm.py

imClumsyPanda 2 years ago
parent
commit
54c983f4bc
1 changed file with 9 additions and 8 deletions
  1. models/chatglm_llm.py  +9 −8

models/chatglm_llm.py  +9 −8

@@ -72,29 +72,29 @@ class ChatGLM(LLM):
         response, _ = self.model.chat(
             self.tokenizer,
             prompt,
-            history=self.history[-self.history_len:] if self.history_len>0 else [],
+            history=self.history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
         )
         torch_gc()
         if stop is not None:
             response = enforce_stop_tokens(response, stop)
-        self.history = self.history+[[None, response]]
+        self.history = self.history + [[None, response]]
         return response
 
     def chat(self,
-              prompt: str) -> str:
+             prompt: str) -> str:
         response, _ = self.model.chat(
             self.tokenizer,
             prompt,
-            history=[],#self.history[-self.history_len:] if self.history_len>0 else 
+            history=self.history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
         )
         torch_gc()
-        self.history = self.history+[[None, response]]
+        self.history = self.history + [[None, response]]
         return response
-        
+
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
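
The main behavioural change in the hunk above is that chat() now forwards the same sliding window of past turns as _call() instead of an empty list. A minimal sketch of that window logic, assuming the [[query, response], ...] history structure and the history_len attribute visible in the diff (the standalone helper name select_history is illustrative, not part of the repo):

    from typing import List, Optional

    def select_history(history: List[List[Optional[str]]],
                       history_len: int) -> List[List[Optional[str]]]:
        # Keep only the last `history_len` turns; a non-positive value disables history.
        return history[-history_len:] if history_len > 0 else []

    # With history_len = 2 only the two most recent turns reach model.chat(),
    # which keeps the prompt bounded as the conversation grows.
    history = [["q1", "a1"], ["q2", "a2"], ["q3", "a3"]]
    print(select_history(history, 2))   # [['q2', 'a2'], ['q3', 'a3']]
    print(select_history(history, 0))   # []
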
@@ -126,7 +126,7 @@ class ChatGLM(LLM):
                     AutoModel.from_pretrained(
                         model_name_or_path,
                         config=model_config,
-                        trust_remote_code=True, 
+                        trust_remote_code=True,
                         **kwargs)
                     .half()
                     .cuda()
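
The hunk above only trims a trailing space, but it sits on the fp16 + CUDA loading path. For context, a hedged sketch of that loading pattern using the public transformers API; the standalone-script form is a simplification (the real method also threads model_config, **kwargs and an LLM_DEVICE check around this call), and running it requires the chatglm-6b weights and a CUDA GPU:

    from transformers import AutoModel, AutoTokenizer

    model_name_or_path = "THUDM/chatglm-6b"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = (
        AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        .half()   # fp16 weights, roughly halving GPU memory
        .cuda()   # move to the default CUDA device
        .eval()   # inference mode, matching the diff's closing line
    )
    # ChatGLM's remote code exposes model.chat(); it returns (response, history).
    response, history = model.chat(tokenizer, "你好", history=[])
    print(response)
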
@@ -159,7 +159,8 @@ class ChatGLM(LLM):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
-            except Exception:
+            except Exception as e:
+                print(e)
                 print("加载PrefixEncoder模型参数失败")
 
         self.model = self.model.eval()
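
The final hunk is the only change to error handling: the caught exception is now bound and printed before the Chinese failure notice ("failed to load PrefixEncoder model parameters"), so the root cause is no longer swallowed. An illustrative, self-contained sketch of that P-Tuning v2 prefix-encoder loading step, assuming a pytorch_model.bin produced by ChatGLM's P-Tuning scripts; the function name and the ptuning_dir parameter are placeholders, while the key-stripping lines mirror the diff:

    import os
    import torch

    def load_prefix_encoder_weights(model, ptuning_dir: str) -> None:
        try:
            prefix_state_dict = torch.load(os.path.join(ptuning_dir, "pytorch_model.bin"),
                                           map_location="cpu")
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                if k.startswith("transformer.prefix_encoder."):
                    # Strip the module prefix so keys match prefix_encoder's own state dict.
                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
            model.transformer.prefix_encoder.float()
        except Exception as e:
            # After this commit the exception itself is printed, not just a generic notice.
            print(e)
            print("加载PrefixEncoder模型参数失败")  # "failed to load PrefixEncoder model parameters"
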