Răsfoiți Sursa

Merge pull request #11722 from akx/mps-gc-fix

Fix MPS cache cleanup
AUTOMATIC1111 2 ani în urmă
părinte
comite
799760ab95
2 fișiere modificate cu 17 adăugiri și 2 ștergeri
  1. 3 2
      modules/devices.py
  2. 14 0
      modules/mac_specific.py

+ 3 - 2
modules/devices.py

@@ -54,8 +54,9 @@ def torch_gc():
         with torch.cuda.device(get_cuda_device_string()):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
-    elif has_mps() and hasattr(torch.mps, 'empty_cache'):
-        torch.mps.empty_cache()
+
+    if has_mps():
+        mac_specific.torch_mps_gc()
 
 
 
 
 def enable_tf32():

+ 14 - 0
modules/mac_specific.py

@@ -1,8 +1,12 @@
+import logging
+
 import torch
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
 
 
+log = logging.getLogger()
+
 
 
 # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
 # use check `getattr` and try it for compatibility.
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
             return False
     else:
         return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
 has_mps = check_for_mps()
 
 
 
 
+def torch_mps_gc() -> None:
+    try:
+        from torch.mps import empty_cache
+        empty_cache()
+    except Exception:
+        log.warning("MPS garbage collection failed", exc_info=True)
+
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
 def cumsum_fix(input, cumsum_func, *args, **kwargs):
     if input.device.type == 'mps':