AUTOMATIC1111 1 year ago
parent
commit
a65dd315ad
1 changed files with 27 additions and 11 deletions
  1. 27 11
      modules/models/sd3/other_impls.py

+ 27 - 11
modules/models/sd3/other_impls.py

@@ -11,6 +11,18 @@ from transformers import CLIPTokenizer, T5TokenizerFast
 #################################################################################################
 
 
+class AutocastLinear(nn.Linear):
+    """Same as usual linear layer, but casts its weights to whatever the parameter type is.
+
+    This is different from torch.autocast in a way that float16 layer processing float32 input
+    will return float16 with autocast on, and float32 with this. T5 seems to be fucked
+    if you do it in full float16 (returning almost all zeros in the final output).
+    """
+
+    def forward(self, x):
+        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+
 def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
@@ -27,9 +39,9 @@ class Mlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
+        self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
         self.act = act_layer
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
+        self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
 
     def forward(self, x):
         x = self.fc1(x)
@@ -297,7 +309,6 @@ class T5XXLModel(SDClipModel):
 ### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
 #################################################################################################
 
-
 class T5XXLTokenizer(SDTokenizer):
     """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
     def __init__(self):
@@ -319,9 +330,9 @@ class T5LayerNorm(torch.nn.Module):
 class T5DenseGatedActDense(torch.nn.Module):
     def __init__(self, model_dim, ff_dim, dtype, device):
         super().__init__()
-        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
 
     def forward(self, x):
         hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
@@ -348,10 +359,10 @@ class T5Attention(torch.nn.Module):
     def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
         super().__init__()
         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
         self.num_heads = num_heads
         self.relative_attention_bias = None
         if relative_attention_bias:
@@ -421,11 +432,16 @@ class T5Attention(torch.nn.Module):
         q = self.q(x)
         k = self.k(x)
         v = self.v(x)
+
         if self.relative_attention_bias is not None:
             past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
         if past_bias is not None:
             mask = past_bias
-        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
+        else:
+            mask = None
+
+        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
+
         return self.o(out), past_bias