
llm_model_dict defines the checkpoint name and remote path

glide-the, 2 years ago
Parent
Commit 7b95fa4aaf
6 changed files with 53 additions and 43 deletions
  1. configs/model_config.py  (+14 -7)
  2. models/chatglm_llm.py  (+2 -2)
  3. models/llama_llm.py  (+0 -12)
  4. models/loader/loader.py  (+2 -2)
  5. models/shared.py  (+22 -7)
  6. webui.py  (+13 -13)

configs/model_config.py  (+14 -7)

@@ -21,31 +21,38 @@ llm_model_dict handles some preset loader behavior, such as the load location and model
 """
 llm_model_dict = {
     "chatglm-6b-int4-qe": {
-        "path": "THUDM/chatglm-6b-int4-qe",
+        "name": "chatglm-6b-int4-qe",
+        "remote-checkpoint": "THUDM/chatglm-6b-int4-qe",
         "provides": "ChatGLM"
     },
     "chatglm-6b-int4": {
-        "path": "THUDM/chatglm-6b-int4",
+        "name": "chatglm-6b-int4",
+        "remote-checkpoint": "THUDM/chatglm-6b-int4",
         "provides": "ChatGLM"
     },
     "chatglm-6b": {
-        "path": "THUDM/chatglm-6b-int4",
+        "name": "chatglm-6b",
+        "remote-checkpoint": "THUDM/chatglm-6b-int4",
         "provides": "ChatGLM"
     },
     "llama-7b-hf": {
-        "path": "llama-7b-hf",
+        "name": "llama-7b-hf",
+        "remote-checkpoint": "llama-7b-hf",
         "provides": "LLamaLLM"
     },
     "vicuna-13b-hf": {
-        "path": "vicuna-13b-hf",
+        "name": "vicuna-13b-hf",
+        "remote-checkpoint": "vicuna-13b-hf",
         "provides": "LLamaLLM"
     },
     "chatyuan": {
-        "path": "ClueAI/ChatYuan-large-v2",
+        "name": "chatyuan",
+        "remote-checkpoint": "ClueAI/ChatYuan-large-v2",
         "provides": None
     },
     "chatglm-6b-int8":{
-        "path":  "THUDM/chatglm-6b-int8",
+        "name": "chatglm-6b-int8",
+        "remote-checkpoint":  "THUDM/chatglm-6b-int8",
         "provides": "ChatGLM"
     },
 }
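Each entry now carries both a local checkpoint name ("name") and a remote hub path ("remote-checkpoint"). A minimal sketch of how these fields are meant to be consumed, mirroring the selection logic added to models/shared.py further below; the helper name resolve_checkpoint is illustrative only and is not part of this commit:

    # Illustrative sketch only: models/shared.py (below) applies this selection inline.
    from configs.model_config import llm_model_dict

    def resolve_checkpoint(model_key: str, no_remote_model: bool = False) -> str:
        info = llm_model_dict[model_key]
        # "name" points at a local checkpoint directory, "remote-checkpoint" at the hub path.
        return info["name"] if no_remote_model else info["remote-checkpoint"]

    # resolve_checkpoint("chatglm-6b-int8")        -> "THUDM/chatglm-6b-int8"
    # resolve_checkpoint("chatglm-6b-int8", True)  -> "chatglm-6b-int8"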

models/chatglm_llm.py  (+2 -2)

@@ -1,5 +1,5 @@
 import json
-from langchain.llms.base import BaseLLM
+from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 
@@ -9,7 +9,7 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from models.loader import LoaderCheckPoint
 
 
-class ChatGLM(BaseLLM):
+class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
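ChatGLM now subclasses langchain's LLM rather than BaseLLM. In the langchain version this repo targets, LLM is the lighter single-prompt base class: a subclass supplies _llm_type and _call. A hedged sketch of that interface; EchoLLM is a toy example, not code from this repo:

    # Toy example of the LLM interface that ChatGLM now implements.
    from typing import List, Optional
    from langchain.llms.base import LLM

    class EchoLLM(LLM):
        @property
        def _llm_type(self) -> str:
            return "echo"

        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            # A real subclass such as ChatGLM would run the loaded model here
            # and apply enforce_stop_tokens to the generated text.
            return prompt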

models/llama_llm.py  (+0 -12)

@@ -135,18 +135,6 @@ class LLamaLLM(LLM):
         filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(self.llm.model.device)
         return inputs_embeds, filler_input_ids
 
-    def callmessage(self, prompt: str, ):
-
-        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'],
-                                truncation_length=self.get_max_prompt_length())
-        self.generate_params.update({'inputs': input_ids})
-
-        with self.generate_with_streaming(**self.generate_params) as generator:
-            for output in generator:
-                new_tokens = len(output) - len(input_ids[0])
-                reply = self.decode(output[-new_tokens:])
-                print(reply)
-
     def _call(self,
               prompt: str,
               stop: Optional[List[str]] = None) -> str:

models/loader/loader.py  (+2 -2)

@@ -194,7 +194,7 @@ class LoaderCheckPoint:
             if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
                 config = AutoConfig.from_pretrained(checkpoint)
                 with init_empty_weights():
-                    model = AutoModelForCausalLM.from_config(config)
+                    model = LoaderClass.from_config(config)
                 model.tie_weights()
                 if self.device_map is not None:
                     params['device_map'] = self.device_map
@@ -206,7 +206,7 @@ class LoaderCheckPoint:
                         no_split_module_classes=model._no_split_modules
                     )
 
-            model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
+            model = LoaderClass.from_pretrained(checkpoint, **params)
 
         # Loading the tokenizer
         if type(model) is transformers.LlamaForCausalLM:
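The hard-coded AutoModelForCausalLM is replaced by a LoaderClass chosen elsewhere in loader.py; that selection is not shown in this diff. As an assumed illustration only, such a loader class is typically picked from the model name:

    # Assumed illustration: the actual LoaderClass assignment is outside this diff.
    from transformers import AutoModel, AutoModelForCausalLM

    def pick_loader_class(model_name: str):
        # ChatGLM checkpoints are commonly loaded via AutoModel (trust_remote_code),
        # while llama/vicuna checkpoints use AutoModelForCausalLM.
        if "chatglm" in model_name.lower():
            return AutoModel
        return AutoModelForCausalLM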

models/shared.py  (+22 -7)

@@ -11,18 +11,33 @@ args = parser.parse_args()
 loaderCheckPoint: LoaderCheckPoint = None
 
 
-def loaderLLM(no_remote_model, use_ptuning_v2):
+def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False):
     """
-    Initialize the LLM
+    Initialize the LLM instance (llm_model_ins)
+    :param llm_model: model_name
     :param no_remote_model: load the checkpoint from a local path instead of the remote hub; pass `--no-remote-model` when loading a local model
     :param use_ptuning_v2: Use p-tuning-v2 PrefixEncoder
     :return:
     """
-    llm_model_info = llm_model_dict[LLM_MODEL]
-    loaderCheckPoint.model_name = llm_model_info['path']
-    loaderCheckPoint.no_remote_model = no_remote_model
-    loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
-    loaderCheckPoint.reload_model()
+    pre_model_name = loaderCheckPoint.model_name
+    llm_model_info = llm_model_dict[pre_model_name]
+
+    if no_remote_model:
+        loaderCheckPoint.no_remote_model = no_remote_model
+    if use_ptuning_v2:
+        loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
+
+    if llm_model:
+        llm_model_info = llm_model_dict[llm_model]
+
+    if loaderCheckPoint.no_remote_model:
+        loaderCheckPoint.model_name = llm_model_info['name']
+    else:
+        loaderCheckPoint.model_name = llm_model_info['remote-checkpoint']
+
+    if llm_model and pre_model_name != llm_model:
+        loaderCheckPoint.reload_model()
+
     provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
     modelInsLLM = provides_class(llm=loaderCheckPoint)
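With the reworked loaderLLM, callers pass the model key instead of mutating the checkpoint state directly, and reload_model() runs only when the requested model differs from the one already loaded. A hedged usage sketch matching the webui.py changes below; the empty argv is purely for illustration:

    # Usage sketch of the new loaderLLM signature (mirrors webui.py below).
    import models.shared as shared
    from models.loader import LoaderCheckPoint
    from models.loader.args import parser

    args_dict = vars(parser.parse_args([]))   # empty argv here, for illustration only
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)

    llm = shared.loaderLLM()                                         # load the default model
    llm = shared.loaderLLM("chatglm-6b-int8", no_remote_model=True)  # switch model, triggers reload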
 

webui.py  (+13 -13)

@@ -1,6 +1,12 @@
+import sys
+
 import gradio as gr
 import os
 import shutil
+
+
+from models.loader.args import parser
+from models.loader import LoaderCheckPoint
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
@@ -83,14 +89,8 @@ def init_model(llm_model: LLM = None):
 def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, top_k, history):
     try:
 
-        llm_model_info = llm_model_dict[llm_model]
-
-        shared.loaderLLM.model_name = llm_model_info['path']
-        shared.loaderLLM.no_remote_model = no_remote_model
-        shared.loaderLLM.use_ptuning_v2 = use_ptuning_v2
-        shared.loaderLLM.reload_model()
-        llm_model_ins = llm_model_info['provides'](shared.loaderLLM)
-
+        llm_model_ins = shared.loaderLLM(llm_model,no_remote_model,use_ptuning_v2)
+        llm_model_ins.history_len = llm_history_len
         local_doc_qa.init_cfg(llm_model=llm_model_ins,
                               embedding_model=embedding_model,
                               top_k=top_k)
@@ -179,11 +179,11 @@ args = None
 args = parser.parse_args()
 
 args_dict = vars(args)
-shared.loaderLLM = LoaderLLM(args_dict)
-chatGLMLLM = ChatGLM(shared.loaderLLM)
-chatGLMLLM.history_len = LLM_HISTORY_LEN
+shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+llm_model_ins = shared.loaderLLM()
+llm_model_ins.history_len = LLM_HISTORY_LEN
 
-model_status = init_model(llm_model=chatGLMLLM)
+model_status = init_model(llm_model=llm_model_ins)
 
 with gr.Blocks(css=block_css) as demo:
     vs_path, file_status, model_status, vs_list = gr.State(""), gr.State(""), gr.State(model_status), gr.State(vs_list)
@@ -262,7 +262,7 @@ with gr.Blocks(css=block_css) as demo:
                              value=LLM_MODEL,
                              interactive=True)
 
-        no_remote_model = gr.Checkbox(shared.loaderLLM.no_remote_model,
+        no_remote_model = gr.Checkbox(shared.loaderCheckPoint.no_remote_model,
                                       label="加载本地模型",
                                       interactive=True)
         llm_history_len = gr.Slider(0,