- import torch.cuda
- import torch.mps
- import torch.backends
def torch_gc(DEVICE):
    """Release cached accelerator memory held by PyTorch.

    CUDA takes priority: when it is available, the caching allocator is
    emptied and IPC memory is collected with ``DEVICE`` made the current
    CUDA device for the duration of those calls. Otherwise, if the Apple
    MPS backend is available, its cache is emptied instead. When neither
    backend is available this is a no-op.

    Args:
        DEVICE: Value handed to ``torch.cuda.device(...)`` on the CUDA path
            (presumably a CUDA device string/index such as ``"cuda:0"`` --
            confirm against callers). Unused on the MPS path.

    Returns:
        None.
    """
    cuda = torch.cuda
    if not cuda.is_available():
        # No CUDA: fall back to the Metal Performance Shaders backend, if any.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        return

    # Scope the cleanup to the requested device so that device's cache is freed.
    with cuda.device(DEVICE):
        cuda.empty_cache()
        cuda.ipc_collect()
|