@@ -1,5 +1,4 @@
-import torch.cuda
-import torch.backends
+import torch
def torch_gc(DEVICE):
if torch.cuda.is_available():
@@ -8,7 +7,6 @@ def torch_gc(DEVICE):
torch.cuda.ipc_collect()
elif torch.backends.mps.is_available():
try:
- import torch.mps
torch.mps.empty_cache()
except Exception as e:
print(e)