__init__.py 303 B

1234567891011
import contextlib

import torch.backends
import torch.cuda
import torch.mps
  4. def torch_gc(DEVICE):
  5. if torch.cuda.is_available():
  6. with torch.cuda.device(DEVICE):
  7. torch.cuda.empty_cache()
  8. torch.cuda.ipc_collect()
  9. elif torch.backends.mps.is_available():
  10. torch.mps.empty_cache()