@@ -1,5 +1,4 @@
import torch.cuda
-import torch.mps
import torch.backends
def torch_gc(DEVICE):
@@ -8,4 +7,9 @@ def torch_gc(DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif torch.backends.mps.is_available():
- torch.mps.empty_cache()
+ try:
+ import torch.mps
+ torch.mps.empty_cache()
+ except Exception as e:
+ print(e)
+ print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")