Support p-tuning-v2

Thaumstrial 2 years ago
parent
commit
2cd52f6605
4 changed files with 49 additions and 5 deletions
  1. configs/model_config.py (+3, -0)
  2. models/chatglm_llm.py (+34, -2)
  3. ptuning-v2/readme.md (+5, -0)
  4. webui.py (+7, -3)

+ 3 - 0
configs/model_config.py

@@ -24,6 +24,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# Use p-tuning-v2 PrefixEncoder
+USE_PTUNING_V2 = False
+
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 

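For context, a quick sketch of what enabling this flag triggers further down (see models/chatglm_llm.py below): the PrefixEncoder hyperparameters are patched into the Hugging Face model config before the weights are loaded. The values shown are assumptions and must match the fine-tuning run (128 is the default PRE_SEQ_LEN in the ChatGLM-6B ptuning scripts).

```python
from transformers import AutoConfig

# Hedged sketch: inject the p-tuning-v2 hyperparameters into the model config
# before calling AutoModel.from_pretrained(); values are illustrative only.
model_config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model_config.pre_seq_len = 128          # assumption: must match the training run
model_config.prefix_projection = False  # assumption: plain PrefixEncoder, no projection
```
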
+ 34 - 2
models/chatglm_llm.py

@@ -1,7 +1,10 @@
+import json
+import os
+
 from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
 
@@ -51,15 +54,30 @@ class ChatGLM(LLM):
 
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
-                   llm_device=LLM_DEVICE):
+                   llm_device=LLM_DEVICE,
+                   use_ptuning_v2=False):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
+
+        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+
+        if use_ptuning_v2:
+            try:
+                with open('ptuning-v2/config.json', 'r') as prefix_encoder_file:
+                    prefix_encoder_config = json.load(prefix_encoder_file)
+
+                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
+                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
+            except Exception:
+                print("Failed to load PrefixEncoder config.json")
+
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .half()
                 .cuda()
@@ -68,8 +86,22 @@ class ChatGLM(LLM):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .float()
                 .to(llm_device)
             )
+
+        if use_ptuning_v2:
+            try:
+                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin', map_location='cpu')
+                new_prefix_state_dict = {}
+                for k, v in prefix_state_dict.items():
+                    if k.startswith("transformer.prefix_encoder."):
+                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+                self.model.transformer.prefix_encoder.float()
+            except Exception:
+                print("Failed to load PrefixEncoder model parameters")
+
         self.model = self.model.eval()

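A minimal usage sketch of the extended loader, assuming the repository root as the working directory and a populated ptuning-v2/ folder:

```python
from models.chatglm_llm import ChatGLM

# Load the base weights plus the PrefixEncoder found in ptuning-v2/
# (expects ptuning-v2/config.json and ptuning-v2/pytorch_model.bin).
llm = ChatGLM()
llm.load_model(model_name_or_path="THUDM/chatglm-6b", use_ptuning_v2=True)

# LangChain's LLM.__call__ dispatches to ChatGLM._call under the hood.
print(llm("Hello"))
```
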
+ 5 - 0
ptuning-v2/readme.md

@@ -0,0 +1,5 @@
+If the model was fine-tuned with [p-tuning-v2](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning), place the resulting PrefixEncoder in this folder.

+Only the checkpoint's *config.json* and *pytorch_model.bin* need to go in here,

+then tick *"Use p-tuning-v2 fine-tuned model"* when loading the model.

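For reference, load_model above reads only two keys from *config.json*, so a minimal file like the one written below is sufficient (the full config.json shipped with a p-tuning-v2 checkpoint also works; the values are assumptions and must match the fine-tuning run):

```python
import json

# Hypothetical minimal ptuning-v2/config.json; only these two keys are read.
with open("ptuning-v2/config.json", "w") as f:
    json.dump({"pre_seq_len": 128, "prefix_projection": False}, f, indent=2)
```
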
+ 7 - 3
webui.py

@@ -53,11 +53,12 @@ def init_model():
         return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
+                              use_ptuning_v2=use_ptuning_v2,
                               top_k=top_k)
         model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
     except:
@@ -97,7 +98,7 @@ webui_title = """
 """
 
 init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
-1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示
+1. 选择语言模型、Embedding 模型及相关参数,如果使用ptuning-v2方式微调过模型,将PrefixEncoder模型放在ptuning-v2文件夹里并勾选相关选项,然后点击"重新加载模型",并等待加载完成提示
 2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
 3. 输入要提交的问题后,点击回车提交 """
 
@@ -127,6 +128,9 @@ with gr.Blocks(css=block_css) as demo:
                                         step=1,
                                         label="LLM history len",
                                         interactive=True)
+            use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
+                                         label="Use p-tuning-v2 fine-tuned model",
+                                         interactive=True)
             embedding_model = gr.Radio(embedding_model_dict_list,
                                       label="Embedding model",
                                        value=EMBEDDING_MODEL,
@@ -152,7 +156,7 @@ with gr.Blocks(css=block_css) as demo:
             load_file_button = gr.Button("Load file")
     load_model_button.click(reinit_model,
                             show_progress=True,
-                            inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot],
+                            inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                             outputs=chatbot
                             )
     # Save the uploaded files under the content folder and refresh the dropdown
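
Note that none of the four changed files touches the LocalDocQA side, so its init_cfg is assumed to already accept and forward the new keyword. A hypothetical sketch of the forwarding that reinit_model above relies on (names mirror the call site; the body is illustrative, with embedding and top_k handling omitted):

```python
from configs.model_config import LLM_DEVICE, llm_model_dict
from models.chatglm_llm import ChatGLM

# Hypothetical forwarding inside LocalDocQA (not shown in this diff).
class LocalDocQA:
    llm: ChatGLM = None

    def init_cfg(self, llm_model="chatglm-6b", embedding_model=None,
                 llm_history_len=3, use_ptuning_v2=False, top_k=6):
        self.llm = ChatGLM()
        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                            llm_device=LLM_DEVICE,
                            use_ptuning_v2=use_ptuning_v2)
```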