__init__.py 512 B

1234567891011121314
  1. import torch
  2. def torch_gc():
  3. if torch.cuda.is_available():
  4. # with torch.cuda.device(DEVICE):
  5. torch.cuda.empty_cache()
  6. torch.cuda.ipc_collect()
  7. elif torch.backends.mps.is_available():
  8. try:
  9. from torch.mps import empty_cache
  10. empty_cache()
  11. except Exception as e:
  12. print(e)
  13. print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")