Prechádzať zdrojové kódy

Make sub-quadratic the default for MPS

brkirch 2 rokov pred
rodič
commit
87dd685224
1 zmenil súbory, kde vykonal 5 pridanie a 2 odobranie
  1. 5 2
      modules/sd_hijack_optimizations.py

+ 5 - 2
modules/sd_hijack_optimizations.py

@@ -95,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
 class SdOptimizationSubQuad(SdOptimization):
     name = "sub-quadratic"
     cmd_opt = "opt_sub_quad_attention"
-    priority = 10
+
+    @property
+    def priority(self):
+        return 1000 if shared.device.type == 'mps' else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -121,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
 
     @property
     def priority(self):
-        return 1000 if not torch.cuda.is_available() else 10
+        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI