chatglm_llm.py

import json
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens
from models.loader.args import parser
from configs.model_config import *
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from models.loader import LoaderCheckPoint


class ChatGLM(LLM):
    # Wraps a ChatGLM checkpoint loaded via LoaderCheckPoint behind the
    # LangChain LLM interface.
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    llm: LoaderCheckPoint = None
    # history = []
    history_len: int = 10

    def __init__(self, llm: LoaderCheckPoint = None):
        super().__init__()
        self.llm = llm

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    def _call(self,
              prompt: str,
              history: List[List[str]] = [],
              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
        # Note: unlike the base LLM._call, this method is a generator that
        # yields (response_so_far, updated_history) tuples.
        if streaming:
            for inum, (stream_resp, _) in enumerate(self.llm.model.stream_chat(
                    self.llm.tokenizer,
                    prompt,
                    # limit the chat context to recent turns
                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
                    max_length=self.max_token,
                    temperature=self.temperature,
            )):
                self.llm.clear_torch_cache()
                if inum == 0:
                    # first chunk: append a new [prompt, response] pair
                    history += [[prompt, stream_resp]]
                else:
                    # later chunks: replace the pair with the longer response
                    history[-1] = [prompt, stream_resp]
                yield stream_resp, history
        else:
            response, _ = self.llm.model.chat(
                self.llm.tokenizer,
                prompt,
                history=history[-self.history_len:] if self.history_len > 0 else [],
                max_length=self.max_token,
                temperature=self.temperature,
            )
            self.llm.clear_torch_cache()
            history += [[prompt, response]]
            yield response, history

    # def chat(self,
    #          prompt: str) -> str:
    #     response, _ = self.model.chat(
    #         self.tokenizer,
    #         prompt,
    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
    #         max_length=self.max_token,
    #         temperature=self.temperature,
    #     )
    #     torch_gc()
    #     self.history = self.history + [[None, response]]
    #     return response


if __name__ == "__main__":
    # Initialization: parse CLI args and load the model checkpoint
    args = parser.parse_args()
    args_dict = vars(args)
    loaderLLM = LoaderCheckPoint(args_dict)
    llm = ChatGLM(loaderLLM)
    llm.history_len = 10

    # Streaming call: print only the newly generated part of each chunk
    last_print_len = 0
    for resp, history in llm._call("你好", streaming=True):
        print(resp[last_print_len:], end="", flush=True)
        last_print_len = len(resp)

    # Non-streaming call: a single full response is yielded
    for resp, history in llm._call("你好", streaming=False):
        print(resp)
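
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): enforce_stop_tokens is
# imported above but never called here. Assuming LangChain's helper keeps its
# (text: str, stop: List[str]) -> str signature, it could be used to truncate
# a reply at the first stop token before yielding it, e.g.:
#
#     trimmed = enforce_stop_tokens("Hello!\nObservation: ...", ["\nObservation:"])
#     # trimmed == "Hello!"
#
# The stop strings are joined into a regex internally, so any regex special
# characters in them would need escaping.
# ---------------------------------------------------------------------------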