Эх сурвалжийг харах

make it possible to use hypernetworks without opt split attention

AUTOMATIC 2 жил өмнө
parent
commit
f7c787eb7c

+ 34 - 8
modules/hypernetwork.py

@@ -4,7 +4,12 @@ import sys
 import traceback
 
 import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
 
     return res
 
-def apply(self, x, context=None, mask=None, original=None):
 
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
     else:
         k = self.to_k(context)
         v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)

+ 4 - 2
modules/sd_hijack.py

@@ -8,7 +8,7 @@ from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention
@@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At
 
 
 def apply_optimizations():
+    undo_optimizations()
+
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
     if cmd_opts.opt_split_attention_v1:
@@ -30,7 +32,7 @@ def apply_optimizations():
 
 
 def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward