|
@@ -173,8 +173,8 @@ def init_cfg():
|
|
|
start_time = time.perf_counter()
|
|
|
print("加载GLM模型......")
|
|
|
chatglm = ChatGLM()
|
|
|
- chatglm.load_model(model_name_or_path=MODEL_CONFIG.llm_model_dict[MODEL_CONFIG.LLM_MODEL])
|
|
|
- chatglm.history_len = MODEL_CONFIG.LLM_HISTORY_LEN
|
|
|
+ chatglm.load_model(model_name_or_path=model_config.llm_model_dict[model_config.LLM_MODEL])
|
|
|
+ chatglm.history_len = model_config.LLM_HISTORY_LEN
|
|
|
print("模型加载完成!!!")
|
|
|
end_time = time.perf_counter()
|
|
|
# 计算操作耗时
|
|
@@ -195,7 +195,7 @@ def init_cfg():
|
|
|
prompt = ChatPromptTemplate.from_messages(messages)
|
|
|
knowledge_chain = RetrievalQA.from_llm(
|
|
|
llm=chatglm,
|
|
|
- retriever=vector_store.as_retriever(search_kwargs={"k": MODEL_CONFIG.VECTOR_SEARCH_TOP_K}),
|
|
|
+ retriever=vector_store.as_retriever(search_kwargs={"k": model_config.VECTOR_SEARCH_TOP_K}),
|
|
|
prompt=prompt
|
|
|
)
|
|
|
knowledge_chain.return_source_documents = False
|