Procházet zdrojové kódy

fix bug in chatglm_llm.py

imClumsyPanda před 2 roky
rodič
revize
f147043253
2 změnil soubory, kde provedl 3 přidání a 9 odebrání
  1. 3 1
      chains/local_doc_qa.py
  2. 0 8
      models/chatglm_llm.py

+ 3 - 1
chains/local_doc_qa.py

@@ -41,10 +41,12 @@ class LocalDocQA:
                  llm_model: str = LLM_MODEL,
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
                  llm_device=LLM_DEVICE,
                  top_k=VECTOR_SEARCH_TOP_K,
                  top_k=VECTOR_SEARCH_TOP_K,
+                 use_ptuning_v2: bool = USE_PTUNING_V2
                  ):
                  ):
         self.llm = ChatGLM()
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device)
+                            llm_device=llm_device,
+                            use_ptuning_v2=use_ptuning_v2)
         self.llm.history_len = llm_history_len
         self.llm.history_len = llm_history_len
 
 
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )

+ 0 - 8
models/chatglm_llm.py

@@ -127,14 +127,6 @@ class ChatGLM(LLM):
                     device_map = auto_configure_device_map(num_gpus)
                     device_map = auto_configure_device_map(num_gpus)
 
 
                 self.model = dispatch_model(model, device_map=device_map)
                 self.model = dispatch_model(model, device_map=device_map)
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    config=model_config,
-                    trust_remote_code=True)
-                .half()
-                .cuda()
-            )
         else:
         else:
             self.model = (
             self.model = (
                 AutoModel.from_pretrained(
                 AutoModel.from_pretrained(