|
@@ -602,7 +602,7 @@ def sdp_attnblock_forward(self, x):
|
|
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
|
|
dtype = q.dtype
|
|
|
if shared.opts.upcast_attn:
|
|
|
- q, k = q.float(), k.float()
|
|
|
+ q, k, v = q.float(), k.float(), v.float()
|
|
|
q = q.contiguous()
|
|
|
k = k.contiguous()
|
|
|
v = v.contiguous()
|