|
@@ -1,8 +1,12 @@
|
|
|
+import logging
|
|
|
+
|
|
|
import torch
|
|
|
import platform
|
|
|
from modules.sd_hijack_utils import CondFunc
|
|
|
from packaging import version
|
|
|
|
|
|
# Module-scoped logger (named after this module) rather than the root logger,
# so mac-specific messages can be filtered/configured independently and are
# attributed to this module in log output.
log = logging.getLogger(__name__)
|
|
|
|
|
|
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
|
|
# use check `getattr` and try it for compatibility.
|
|
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
|
|
|
return False
|
|
|
else:
|
|
|
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
|
|
+
|
|
|
+
|
|
|
# Computed once at import time: True when this torch build reports Apple's
# MPS backend as both built and available (see check_for_mps above).
has_mps = check_for_mps()
|
|
|
|
|
|
|
|
|
def torch_mps_gc() -> None:
    """Best-effort release of cached MPS memory; logs a warning instead of raising."""
    try:
        # torch.mps may not exist on older torch builds, hence the lazy import
        # inside the try block rather than at module level.
        import torch.mps
        torch.mps.empty_cache()
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)
|
|
|
+
|
|
|
+
|
|
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
|
|
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|
|
if input.device.type == 'mps':
|