
Merge pull request #11722 from akx/mps-gc-fix

Fix MPS cache cleanup
AUTOMATIC1111, 2 years ago
commit 799760ab95
2 files changed, 17 insertions and 2 deletions
  1. modules/devices.py (+3, -2)
  2. modules/mac_specific.py (+14, -0)

modules/devices.py (+3, -2)

@@ -54,8 +54,9 @@ def torch_gc():
         with torch.cuda.device(get_cuda_device_string()):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
-    elif has_mps() and hasattr(torch.mps, 'empty_cache'):
-        torch.mps.empty_cache()
+
+    if has_mps():
+        mac_specific.torch_mps_gc()
 
 
 def enable_tf32():
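
For context, a minimal sketch of how the updated torch_gc call path fits together after this change. The has_mps() stand-in and the CUDA branch are simplified reconstructions, not the repository's exact code, and the modules.mac_specific import assumes the webui source tree is importable:

import sys

import torch


def has_mps() -> bool:
    # Simplified stand-in for the has_mps() helper used by modules/devices.py;
    # the real check also covers older PyTorch builds (see check_for_mps in mac_specific.py).
    mps = getattr(torch.backends, "mps", None)
    return sys.platform == "darwin" and mps is not None and mps.is_available()


def torch_gc():
    if torch.cuda.is_available():
        # Release cached CUDA allocations back to the driver.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # The hasattr(torch.mps, 'empty_cache') guard is gone: the compatibility
    # handling now lives in mac_specific.torch_mps_gc(), so the caller stays simple.
    if has_mps():
        from modules import mac_specific  # resolvable when the webui repo is on sys.path
        mac_specific.torch_mps_gc()

Note that the MPS branch is now a plain if after the CUDA block rather than an elif, and it no longer depends on torch.mps.empty_cache existing at this call site.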

modules/mac_specific.py (+14, -0)

@@ -1,8 +1,12 @@
+import logging
+
 import torch
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
 
+log = logging.getLogger()
+
 
 # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
 # use check `getattr` and try it for compatibility.
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
             return False
     else:
         return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
 has_mps = check_for_mps()
 
 
+def torch_mps_gc() -> None:
+    try:
+        from torch.mps import empty_cache
+        empty_cache()
+    except Exception:
+        log.warning("MPS garbage collection failed", exc_info=True)
+
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
 def cumsum_fix(input, cumsum_func, *args, **kwargs):
     if input.device.type == 'mps':
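
Because the torch.mps import happens inside the try block, torch_mps_gc() degrades gracefully on PyTorch builds that lack torch.mps.empty_cache: the failure is logged instead of propagating into torch_gc(), which is why the hasattr guard in devices.py could be dropped. A small usage sketch follows; the logging setup and the import path are illustrative assumptions, not part of the commit:

import logging

# Make the root-logger warning emitted by torch_mps_gc() visible on the console.
logging.basicConfig(level=logging.WARNING)

from modules import mac_specific  # assumes the webui repository is on sys.path

if mac_specific.has_mps:
    # Frees the MPS allocator cache on Apple-silicon builds of PyTorch;
    # on builds without torch.mps.empty_cache it only logs a warning.
    mac_specific.torch_mps_gc()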