|
@@ -1,13 +1,9 @@
|
|
|
import json
|
|
|
from langchain.llms.base import LLM
|
|
|
-from typing import Optional, List
|
|
|
-from langchain.llms.utils import enforce_stop_tokens
|
|
|
+from typing import List, Dict, Optional
|
|
|
from transformers import AutoTokenizer, AutoModel, AutoConfig
|
|
|
import torch
|
|
|
from configs.model_config import *
|
|
|
-from langchain.callbacks.base import CallbackManager
|
|
|
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
|
-from typing import Dict, Tuple, Union, Optional
|
|
|
from utils import torch_gc
|
|
|
|
|
|
DEVICE_ = LLM_DEVICE
|
|
@@ -53,7 +49,6 @@ class ChatGLM(LLM):
|
|
|
tokenizer: object = None
|
|
|
model: object = None
|
|
|
history_len: int = 10
|
|
|
- callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
|
|
|
|
|
|
def __init__(self):
|
|
|
super().__init__()
|