import torch.cuda
import torch.mps
import torch.backends


def _names_non_cuda_device(device):
    """Return True when *device* explicitly designates a non-CUDA device."""
    if isinstance(device, str):
        device = torch.device(device)
    # Ints and None are left alone: they cannot name a non-CUDA device, so
    # they are handed to ``torch.cuda.device`` unchanged.
    return isinstance(device, torch.device) and device.type != "cuda"


def _release_cuda_cache():
    """Empty the CUDA cache and collect CUDA IPC memory on the current device."""
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def torch_gc(DEVICE):
    """Release cached accelerator memory held by PyTorch.

    When CUDA is available, ``torch.cuda.empty_cache()`` and
    ``torch.cuda.ipc_collect()`` are called with ``DEVICE`` selected as the
    current CUDA device. Otherwise, when the MPS backend is available,
    ``torch.mps.empty_cache()`` is called. With neither backend available
    the function does nothing.

    Args:
        DEVICE: Device whose cache should be released, in any form accepted
            by ``torch.cuda.device`` (e.g. ``"cuda:0"``, a ``torch.device``,
            an index, or ``None``). It is only consulted on the CUDA path.
            A value naming a non-CUDA device (such as ``"cpu"``) is
            tolerated: the cache of the current CUDA device is released
            instead.

    Returns:
        None.
    """
    if torch.cuda.is_available():
        if _names_non_cuda_device(DEVICE):
            # ``torch.cuda.device`` refuses non-CUDA devices; a cleanup
            # helper should not fail just because the caller runs on CPU
            # on a machine that also has a GPU.
            _release_cuda_cache()
        else:
            with torch.cuda.device(DEVICE):
                _release_cuda_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()