|
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
|
|
from types import MethodType
|
|
|
|
|
|
import modules.textual_inversion.textual_inversion
|
|
|
-from modules import devices, sd_hijack_optimizations, shared
|
|
|
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
|
|
|
from modules.hypernetworks import hypernetwork
|
|
|
from modules.shared import cmd_opts
|
|
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
|
@@ -28,57 +28,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
|
|
|
ldm.modules.attention.print = lambda *args: None
|
|
|
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
|
|
|
|
|
+optimizers = []
|
|
|
+current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
|
|
+
|
|
|
+
|
|
|
+def list_optimizers():
|
|
|
+ new_optimizers = script_callbacks.list_optimizers_callback()
|
|
|
+
|
|
|
+ new_optimizers = [x for x in new_optimizers if x.is_available()]
|
|
|
+
|
|
|
+ new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
|
|
+
|
|
|
+ optimizers.clear()
|
|
|
+ optimizers.extend(new_optimizers)
|
|
|
+
|
|
|
|
|
|
def apply_optimizations():
|
|
|
+ global current_optimizer
|
|
|
+
|
|
|
undo_optimizations()
|
|
|
|
|
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
|
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
|
|
|
|
|
- optimization_method = None
|
|
|
+ if current_optimizer is not None:
|
|
|
+ current_optimizer.undo()
|
|
|
+ current_optimizer = None
|
|
|
+
|
|
|
+ selection = shared.opts.cross_attention_optimization
|
|
|
+ if selection == "Automatic" and len(optimizers) > 0:
|
|
|
+ matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
|
|
+ else:
|
|
|
+ matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
|
|
|
|
|
|
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
|
|
|
-
|
|
|
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
|
|
- print("Applying xformers cross attention optimization.")
|
|
|
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
|
|
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
|
|
- optimization_method = 'xformers'
|
|
|
- elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
|
|
|
- print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
|
|
|
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
|
|
|
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
|
|
|
- optimization_method = 'sdp-no-mem'
|
|
|
- elif cmd_opts.opt_sdp_attention and can_use_sdp:
|
|
|
- print("Applying scaled dot product cross attention optimization.")
|
|
|
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
|
|
|
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
|
|
|
- optimization_method = 'sdp'
|
|
|
- elif cmd_opts.opt_sub_quad_attention:
|
|
|
- print("Applying sub-quadratic cross attention optimization.")
|
|
|
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
|
|
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
|
|
|
- optimization_method = 'sub-quadratic'
|
|
|
- elif cmd_opts.opt_split_attention_v1:
|
|
|
- print("Applying v1 cross attention optimization.")
|
|
|
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
|
|
- optimization_method = 'V1'
|
|
|
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
|
|
|
- print("Applying cross attention optimization (InvokeAI).")
|
|
|
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
|
|
- optimization_method = 'InvokeAI'
|
|
|
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
|
|
- print("Applying cross attention optimization (Doggettx).")
|
|
|
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
|
|
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
|
|
- optimization_method = 'Doggettx'
|
|
|
-
|
|
|
- return optimization_method
|
|
|
+ if selection == "None":
|
|
|
+ matching_optimizer = None
|
|
|
+ elif matching_optimizer is None:
|
|
|
+ matching_optimizer = optimizers[0]
|
|
|
+
|
|
|
+ if matching_optimizer is not None:
|
|
|
+ print(f"Applying optimization: {matching_optimizer.name}")
|
|
|
+ matching_optimizer.apply()
|
|
|
+ current_optimizer = matching_optimizer
|
|
|
+ return current_optimizer.name
|
|
|
+ else:
|
|
|
+ return ''
|
|
|
|
|
|
|
|
|
def undo_optimizations():
|
|
|
- ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
|
|
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
|
|
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
|
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
|
|
|
|
|
|
|
@@ -169,7 +168,11 @@ class StableDiffusionModelHijack:
|
|
|
if m.cond_stage_key == "edit":
|
|
|
sd_hijack_unet.hijack_ddpm_edit()
|
|
|
|
|
|
- self.optimization_method = apply_optimizations()
|
|
|
+ try:
|
|
|
+ self.optimization_method = apply_optimizations()
|
|
|
+ except Exception as e:
|
|
|
+ errors.display(e, "applying cross attention optimization")
|
|
|
+ undo_optimizations()
|
|
|
|
|
|
self.clip = m.cond_stage_model
|
|
|
|
|
@@ -223,6 +226,10 @@ class StableDiffusionModelHijack:
|
|
|
|
|
|
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
|
|
|
|
|
+ def redo_hijack(self, m):
|
|
|
+ self.undo_hijack(m)
|
|
|
+ self.hijack(m)
|
|
|
+
|
|
|
|
|
|
class EmbeddingsWithFixes(torch.nn.Module):
|
|
|
def __init__(self, wrapped, embeddings):
|