@@ -12,13 +12,22 @@ from modules import shared
 def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     h = self.heads
 
-    q = self.to_q(x)
+    q_in = self.to_q(x)
     context = default(context, x)
-    k = self.to_k(context)
-    v = self.to_v(context)
+
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k_in = self.to_k(hypernetwork_layers[0](context))
+        v_in = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
     del context, x
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+    del q_in, k_in, v_in
 
     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
     for i in range(0, q.shape[0], 2):
@@ -31,6 +40,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
 
         r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
         del s2
+    del q, k, v
 
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1