
Optional LoRA weight loading (#231)

* Add files via upload

Add support for loading LoRA weights

* Update model_config.py

* Add files via upload

Fix a minor bug: the model-loading call was missing

* Use LoRA fine-tuned weights

* Update model_config.py
shrimp · 2 years ago
commit 14d998b8e6
5 files changed: 36 insertions(+), 40 deletions(-)
  1. chains/local_doc_qa.py (+4 -6)
  2. configs/model_config.py (+5 -0)
  3. models/chatglm_llm.py (+20 -32)
  4. requirements.txt (+1 -0)
  5. webui.py (+6 -2)

chains/local_doc_qa.py (+4 -6)

@@ -61,9 +61,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
 
 
 def similarity_search_with_score_by_vector(
-        self,
-        embedding: List[float],
-        k: int = 4,
+        self, embedding: List[float], k: int = 4,
 ) -> List[Tuple[Document, float]]:
     scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
     docs = []
@@ -122,12 +120,12 @@ class LocalDocQA:
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
                  top_k=VECTOR_SEARCH_TOP_K,
-                 use_ptuning_v2: bool = USE_PTUNING_V2
+                 use_ptuning_v2: bool = USE_PTUNING_V2,
+                 use_lora: bool = USE_LORA,
                  ):
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device,
-                            use_ptuning_v2=use_ptuning_v2)
+                            llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
         self.llm.history_len = llm_history_len
 
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
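For reference, a minimal sketch of driving the updated signature (the use_lora keyword comes from this diff; instantiating with defaults for every other argument is an assumption):

    from chains.local_doc_qa import LocalDocQA

    local_doc_qa = LocalDocQA()
    # use_lora defaults to USE_LORA from configs/model_config.py;
    # passing it explicitly overrides the config value.
    local_doc_qa.init_cfg(use_lora=True)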

configs/model_config.py (+5 -0)

@@ -27,6 +27,11 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM LoRA path. Empty by default; to use an adapter, set this to the adapter folder path.
+# The chatglm-6b-belle-zh-lora adapter is recommended.
+LLM_LORA_PATH = ""
+USE_LORA = bool(LLM_LORA_PATH)
+
 # LLM streaming response
 STREAMING = True
 

models/chatglm_llm.py (+20 -32)

@@ -78,11 +78,11 @@ class ChatGLM(LLM):
                 torch_gc()
         else:
             response, _ = self.model.chat(
-                    self.tokenizer,
-                    prompt,
-                    history=history[-self.history_len:] if self.history_len > 0 else [],
-                    max_length=self.max_token,
-                    temperature=self.temperature,
+                self.tokenizer,
+                prompt,
+                history=history[-self.history_len:] if self.history_len > 0 else [],
+                max_length=self.max_token,
+                temperature=self.temperature,
             )
             torch_gc()
             history += [[prompt, response]]
@@ -106,6 +106,7 @@ class ChatGLM(LLM):
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
                    use_ptuning_v2=False,
+                   use_lora=False,
                    device_map: Optional[Dict[str, int]] = None,
                    **kwargs):
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -125,45 +126,32 @@ class ChatGLM(LLM):
             except Exception as e:
                 print(e)
                 print("加载PrefixEncoder config.json失败")
+        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
+                                               **kwargs)
+        if LLM_LORA_PATH and use_lora:
+            from peft import PeftModel
+            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
 
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             # Decide whether to use multi-GPU deployment based on the number of GPUs available
             num_gpus = torch.cuda.device_count()
             if num_gpus < 2 and device_map is None:
-                self.model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        config=model_config,
-                        trust_remote_code=True,
-                        **kwargs)
-                    .half()
-                    .cuda()
-                )
+                self.model = self.model.half().cuda()
             else:
                 from accelerate import dispatch_model
 
-                model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        trust_remote_code=True,
-                        config=model_config,
-                        **kwargs)
-                    .half())
+                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
+                                                  config=model_config, **kwargs)
+                if LLM_LORA_PATH and use_lora:
+                    from peft import PeftModel
+                    # wrap the base model with the LoRA adapter before dispatching;
+                    # assigning back to `model` keeps the no-LoRA path working too
+                    model = PeftModel.from_pretrained(model, LLM_LORA_PATH)
                 # A device_map can be passed in to customize placement on each GPU
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)

-                self.model = dispatch_model(model, device_map=device_map)
+                self.model = dispatch_model(model.half(), device_map=device_map)
         else:
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    config=model_config,
-                    trust_remote_code=True,
-                    **kwargs)
-                .float()
-                .to(llm_device)
-            )
+            self.model = self.model.float().to(llm_device)
 
         if use_ptuning_v2:
             try:
@@ -185,7 +173,7 @@ if __name__ == "__main__":
     llm = ChatGLM()
     llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
                    llm_device=LLM_DEVICE, )
-    last_print_len=0
+    last_print_len = 0
     for resp, history in llm._call("你好", streaming=True):
         print(resp[last_print_len:], end="", flush=True)
         last_print_len = len(resp)
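In isolation, the loading pattern this diff introduces reduces to the sketch below (a minimal single-GPU illustration of the public peft API; the adapter path is a placeholder, not a real directory):

    import torch
    from transformers import AutoModel
    from peft import PeftModel

    # load the base ChatGLM weights, then attach the LoRA adapter on top
    base = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = PeftModel.from_pretrained(base, "/path/to/lora-adapter")
    # same precision/device handling as load_model: fp16 on GPU, fp32 on CPU
    model = model.half().cuda() if torch.cuda.is_available() else model.float()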

requirements.txt (+1 -0)

@@ -12,4 +12,5 @@ accelerate
 gradio==3.24.1
 fastapi
 uvicorn
+peft
 #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
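After pulling this change, reinstall dependencies (pip install -r requirements.txt, or simply pip install peft) so the new import resolves.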

webui.py (+6 -2)

@@ -72,12 +72,13 @@ def init_model():
         return reply
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
+                              use_lora=use_lora,
                               top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
@@ -246,6 +247,9 @@ with gr.Blocks(css=block_css) as demo:
         use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                      label="Use a model fine-tuned with P-Tuning v2",
                                      interactive=True)
+        use_lora = gr.Checkbox(USE_LORA,
+                               label="Use LoRA fine-tuned weights",
+                               interactive=True)
         embedding_model = gr.Radio(embedding_model_dict_list,
                                    label="Embedding model",
                                    value=EMBEDDING_MODEL,
@@ -259,7 +263,7 @@ with gr.Blocks(css=block_css) as demo:
         load_model_button = gr.Button("重新加载模型")
     load_model_button.click(reinit_model,
                             show_progress=True,
-                            inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
+                            inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
                             outputs=chatbot
                             )
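With this wiring, ticking the new checkbox and clicking the reload button passes use_lora through reinit_model into init_cfg, so the adapter can be toggled from the UI without editing configs/model_config.py.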