shared.py

import sys
from typing import Optional

from models.loader.args import parser
from models.loader import LoaderCheckPoint
from configs.model_config import llm_model_dict, LLM_MODEL
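
# The keys this module reads from each llm_model_dict entry, shown as an
# illustrative sketch (values are assumptions; the real entries live in
# configs/model_config.py):
#
#   llm_model_dict = {
#       "chatglm-6b": {
#           "name": "chatglm-6b",                     # local model name
#           "remote-checkpoint": "THUDM/chatglm-6b",  # remote hub id
#           "path": None,                             # local checkpoint path
#           "provides": "ChatGLM",                    # class name in the models package
#       },
#   }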
  5. """迭代器是否停止状态"""
  6. stop_everything = False
  7. args = parser.parse_args()
  8. loaderCheckPoint: LoaderCheckPoint = None


def loaderLLM(llm_model: Optional[str] = None,
              no_remote_model: bool = False,
              use_ptuning_v2: bool = False):
    """
    Initialize and return an LLM instance (llm_model_ins).

    :param llm_model: model name; a key of llm_model_dict
    :param no_remote_model: load the checkpoint from a local path instead of
        the remote hub (set by the `--no-remote-model` command-line flag)
    :param use_ptuning_v2: use a P-Tuning v2 PrefixEncoder
    :return: the loaded LLM wrapper instance
    """
    # Default to the configuration of the currently loaded model; this is
    # overridden below when an explicit llm_model is requested.
    pre_model_name = loaderCheckPoint.model_name
    llm_model_info = llm_model_dict[pre_model_name]

    if no_remote_model:
        loaderCheckPoint.no_remote_model = no_remote_model
    if use_ptuning_v2:
        loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2

    if llm_model:
        llm_model_info = llm_model_dict[llm_model]

    # Pick the local model name or the remote checkpoint id, then reload weights.
    if loaderCheckPoint.no_remote_model:
        loaderCheckPoint.model_name = llm_model_info['name']
    else:
        loaderCheckPoint.model_name = llm_model_info['remote-checkpoint']
    loaderCheckPoint.model_path = llm_model_info['path']
    loaderCheckPoint.reload_model()

    # Wrap the reloaded checkpoint in the provider class named by the config,
    # looked up dynamically in the models package.
    provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
    modelInsLLM = provides_class(llm=loaderCheckPoint)
    return modelInsLLM
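
# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of how a caller is expected to drive this module,
# assuming LoaderCheckPoint accepts the parsed-args dict in its constructor
# (an assumption; check models/loader for the actual signature):
#
#   import models.shared as shared
#   from models.loader import LoaderCheckPoint
#   from configs.model_config import LLM_MODEL
#
#   shared.loaderCheckPoint = LoaderCheckPoint(vars(shared.args))
#   llm = shared.loaderLLM(llm_model=LLM_MODEL, use_ptuning_v2=False)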