|
@@ -11,6 +11,18 @@ from transformers import CLIPTokenizer, T5TokenizerFast
|
|
#################################################################################################
|
|
#################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
+class AutocastLinear(nn.Linear):
|
|
|
|
+ """Same as usual linear layer, but casts its weights to whatever the parameter type is.
|
|
|
|
+
|
|
|
|
+ This is different from torch.autocast in a way that float16 layer processing float32 input
|
|
|
|
+ will return float16 with autocast on, and float32 with this. T5 seems to be fucked
|
|
|
|
+ if you do it in full float16 (returning almost all zeros in the final output).
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ def forward(self, x):
|
|
|
|
+ return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
|
|
|
|
+
|
|
|
|
+
|
|
def attention(q, k, v, heads, mask=None):
|
|
def attention(q, k, v, heads, mask=None):
|
|
"""Convenience wrapper around a basic attention operation"""
|
|
"""Convenience wrapper around a basic attention operation"""
|
|
b, _, dim_head = q.shape
|
|
b, _, dim_head = q.shape
|
|
@@ -27,9 +39,9 @@ class Mlp(nn.Module):
|
|
out_features = out_features or in_features
|
|
out_features = out_features or in_features
|
|
hidden_features = hidden_features or in_features
|
|
hidden_features = hidden_features or in_features
|
|
|
|
|
|
- self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
|
|
|
|
|
+ self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
|
self.act = act_layer
|
|
self.act = act_layer
|
|
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
|
|
|
|
|
+ self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
|
|
|
|
|
def forward(self, x):
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.fc1(x)
|
|
@@ -297,7 +309,6 @@ class T5XXLModel(SDClipModel):
|
|
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
|
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
|
#################################################################################################
|
|
#################################################################################################
|
|
|
|
|
|
-
|
|
|
|
class T5XXLTokenizer(SDTokenizer):
|
|
class T5XXLTokenizer(SDTokenizer):
|
|
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
|
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
|
def __init__(self):
|
|
def __init__(self):
|
|
@@ -319,9 +330,9 @@ class T5LayerNorm(torch.nn.Module):
|
|
class T5DenseGatedActDense(torch.nn.Module):
|
|
class T5DenseGatedActDense(torch.nn.Module):
|
|
def __init__(self, model_dim, ff_dim, dtype, device):
|
|
def __init__(self, model_dim, ff_dim, dtype, device):
|
|
super().__init__()
|
|
super().__init__()
|
|
- self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
- self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
- self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
|
|
+ self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
+ self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
+ self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
|
|
def forward(self, x):
|
|
def forward(self, x):
|
|
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
|
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
|
@@ -348,10 +359,10 @@ class T5Attention(torch.nn.Module):
|
|
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
|
|
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
|
|
super().__init__()
|
|
super().__init__()
|
|
# Mesh TensorFlow initialization to avoid scaling before softmax
|
|
# Mesh TensorFlow initialization to avoid scaling before softmax
|
|
- self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
- self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
- self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
- self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
|
|
+ self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
+ self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
+ self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
|
|
+ self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
|
self.num_heads = num_heads
|
|
self.num_heads = num_heads
|
|
self.relative_attention_bias = None
|
|
self.relative_attention_bias = None
|
|
if relative_attention_bias:
|
|
if relative_attention_bias:
|
|
@@ -421,11 +432,16 @@ class T5Attention(torch.nn.Module):
|
|
q = self.q(x)
|
|
q = self.q(x)
|
|
k = self.k(x)
|
|
k = self.k(x)
|
|
v = self.v(x)
|
|
v = self.v(x)
|
|
|
|
+
|
|
if self.relative_attention_bias is not None:
|
|
if self.relative_attention_bias is not None:
|
|
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
|
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
|
if past_bias is not None:
|
|
if past_bias is not None:
|
|
mask = past_bias
|
|
mask = past_bias
|
|
- out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
|
|
|
|
|
|
+ else:
|
|
|
|
+ mask = None
|
|
|
|
+
|
|
|
|
+ out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
|
|
|
|
+
|
|
return self.o(out), past_bias
|
|
return self.o(out), past_bias
|
|
|
|
|
|
|
|
|