@@ -4,7 +4,7 @@ import os
 import nltk
 from models.loader.args import parser
 import models.shared as shared
-from models.loader import LoaderLLM
+from models.loader import LoaderCheckPoint
 from models.chatglm_llm import ChatGLM
 
 nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
@@ -22,13 +22,12 @@ if __name__ == "__main__":
     args = None
     args = parser.parse_args()
     args_dict = vars(args)
-
-    shared.loaderLLM = LoaderLLM(args_dict)
-    chatGLMLLM = ChatGLM(shared.loaderLLM)
-    chatGLMLLM.history_len = LLM_HISTORY_LEN
+    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+    llm_model_ins = shared.loaderLLM()
+    llm_model_ins.history_len = LLM_HISTORY_LEN
 
     local_doc_qa = LocalDocQA()
-    local_doc_qa.init_cfg(llm_model=chatGLMLLM,
+    local_doc_qa.init_cfg(llm_model=llm_model_ins,
                            embedding_model=EMBEDDING_MODEL,
                            embedding_device=EMBEDDING_DEVICE,
                            top_k=VECTOR_SEARCH_TOP_K)
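
For context, the startup flow after this change builds a single LoaderCheckPoint from the parsed CLI args, stores it on models.shared, and then obtains the LLM instance through shared.loaderLLM() instead of constructing ChatGLM directly. A minimal sketch of that sequence, assuming shared.loaderLLM() reads shared.loaderCheckPoint to build and return the model wrapper, and that the constants come from the project's configs module (both are assumptions, not shown in this diff):

# Sketch of the new startup sequence; shared.loaderLLM() is assumed to
# use the globally shared LoaderCheckPoint to construct the LLM wrapper.
from configs.model_config import (EMBEDDING_MODEL, EMBEDDING_DEVICE,
                                  LLM_HISTORY_LEN, VECTOR_SEARCH_TOP_K)  # assumed location of these constants
from models.loader import LoaderCheckPoint
from models.loader.args import parser
import models.shared as shared

args_dict = vars(parser.parse_args())

# One checkpoint loader is created from the CLI args and shared globally...
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
# ...and the concrete LLM instance is obtained via the shared factory
# rather than by instantiating ChatGLM directly.
llm_model_ins = shared.loaderLLM()
llm_model_ins.history_len = LLM_HISTORY_LEN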