other_impls.py

### This file contains impls for underlying related models (CLIP, T5, etc)

import torch
import math

from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast

from modules import sd_hijack


#################################################################################################
### Core/Utility
#################################################################################################


class AutocastLinear(nn.Linear):
    """Same as the usual linear layer, but casts its weights to the dtype of the input.

    This differs from torch.autocast: a float16 layer processing float32 input returns float16
    under autocast, but float32 here. T5 appears to break in full float16 (returning almost all
    zeros in the final output).
    """

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)


def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
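

# Illustrative usage sketch (comment only, not executed): attention() expects q/k/v that are
# already projected and laid out as (batch, seq, heads * dim_head); the sizes below are
# assumptions for the example, not taken from any particular model config.
#
#   q = k = v = torch.randn(2, 77, 768)    # batch=2, seq=77, 12 heads of width 64
#   out = attention(q, k, v, heads=12)     # -> (2, 77, 768), same layout as the inputs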


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        self.act = act_layer
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


#################################################################################################
### CLIP
#################################################################################################


class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = attention(q, k, v, self.heads, mask)
        return self.out_proj(out)


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}
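

# Quick comparison of the two activations (values approximate, shown only to illustrate how
# close they are; which one a model uses comes from its config's "hidden_act" field):
#
#   x = torch.tensor([1.0])
#   ACTIVATIONS["quick_gelu"](x)   # ~0.8458  (sigmoid approximation: x * sigmoid(1.702 * x))
#   ACTIVATIONS["gelu"](x)         # ~0.8413  (exact erf-based GELU from torch)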


class CLIPLayer(torch.nn.Module):
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
        self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)

    def forward(self, x, mask=None):
        x += self.self_attn(self.layer_norm1(x), mask)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output

        intermediate = None
        for i, layer in enumerate(self.layers):
            x = layer(x, mask)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class CLIPEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
        super().__init__()
        self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight


class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)
        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output
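

# Note on the pooled output above: argmax over the raw token ids lands on the end-of-text
# token, because its id (49407) is the largest in the 49408-entry CLIP vocabulary; the pooled
# vector is therefore the final hidden state at the EOT position.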


class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])


class SDTokenizer:
    def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        empty = self.tokenizer('')["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8

    def tokenize_with_weights(self, text: str):
        """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.

        The details aren't relevant for a reference impl, and the weights themselves have only a weak effect on SD3."""
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0))
        to_tokenize = text.replace("\n", " ").split(' ')
        to_tokenize = [x for x in to_tokenize if x != ""]
        for word in to_tokenize:
            batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
        batch.append((self.end_token, 1.0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
        return [batch]
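

# Rough sketch of the return value (the word-token ids below are illustrative assumptions, not
# exact vocabulary lookups): for the prompt "a cat", a CLIP-L tokenizer configured with the
# defaults above would produce something like
#
#   [[(49406, 1.0),             # start token
#     (320, 1), (2368, 1),      # word tokens for "a" and "cat" (illustrative ids)
#     (49407, 1.0),             # end token
#     (49407, 1.0), ...]]       # padded with the end token up to max_length=77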


class SDXLClipGTokenizer(SDTokenizer):
    def __init__(self, tokenizer):
        super().__init__(pad_with_end=False, tokenizer=tokenizer)


class SD3Tokenizer:
    def __init__(self):
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        self.t5xxl = T5XXLTokenizer()

    def tokenize_with_weights(self, text: str):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text)
        out["l"] = self.clip_l.tokenize_with_weights(text)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
        return out
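

# Usage sketch (assumes the huggingface tokenizer downloads succeed; comment only, not executed):
#
#   tokenizer = SD3Tokenizer()
#   tokens = tokenizer.tokenize_with_weights("a photo of a cat")
#   tokens["l"], tokens["g"]    # 77-entry CLIP-L / CLIP-G batches, padded as described above
#   tokens["t5xxl"]             # T5 batch: not padded to a fixed max, but at least 77 entries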


class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        tokens = [a[0] for a in token_weight_pairs[0]]
        out, pooled = self([tokens])
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = ["last", "pooled", "hidden"]

    def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device)
        outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
        self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]

        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        return z.float(), pooled_output
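

# Sketch of how the tokenizer output feeds this class (the config dict, dtype and keyword
# choices below are assumptions for illustration; the real values and weight loading come from
# the SD3 loader elsewhere in the repo):
#
#   clip_l = SDClipModel(layer="hidden", layer_idx=-2, textmodel_json_config=clip_l_config,
#                        dtype=torch.float16, layer_norm_hidden_state=False,
#                        return_projected_pooled=False)
#   z, pooled = clip_l.encode_token_weights(tokens["l"])   # z: (1, 77, 768), pooled: (1, 768)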


class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""
    def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2

        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)


class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)


#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################


class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
    def __init__(self):
        super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x
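

# Note: this is RMSNorm rather than a standard LayerNorm -- the mean is never subtracted and
# there is no bias term, matching T5's layer norm. In effect: y = weight * x / sqrt(mean(x^2) + eps).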


class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x
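

# Note: this is the gated-GELU ("GEGLU") feed-forward used by T5 v1.1 -- one projection goes
# through a tanh-approximated GELU and gates the other, i.e. roughly wo(gelu(wi_0(x)) * wi_1(x)).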


class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x


class T5Attention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads

        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)

        if past_bias is not None:
            mask = past_bias
        else:
            mask = None

        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)

        return self.o(out), past_bias
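

# Hand-worked example of the relative-position bucketing above, with the defaults
# (num_buckets=32, max_distance=128, bidirectional=True): after halving, 16 buckets cover each
# direction; small offsets (|offset| < 8) map one-to-one to buckets, offsets up to 127 are
# binned logarithmically, and anything beyond max_distance is clamped to the last bucket.
#
#   rel = torch.tensor([0, 1, 7, 8, 20, 127, 500])       # memory_position - query_position
#   T5Attention._relative_position_bucket(rel)
#   # -> tensor([ 0, 17, 23, 24, 26, 31, 31])            # negative offsets land in buckets 0..15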


class T5LayerSelfAttention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
        x += output
        return x, past_bias


class T5Block(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))

    def forward(self, x, past_bias=None):
        x, past_bias = self.layer[0](x, past_bias)
        x = self.layer[-1](x)
        return x, past_bias


class T5Stack(torch.nn.Module):
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
        intermediate = None
        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
        past_bias = None
        for i, layer in enumerate(self.block):
            x, past_bias = layer(x, past_bias)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate


class T5(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.encoder.embed_tokens = embeddings

    def forward(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)
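

# End-to-end usage sketch (illustrative only: the config dict, state-dict loading and device
# placement are assumed to be handled by the SD3 loader elsewhere in the repo):
#
#   tokenizer = SD3Tokenizer()
#   tokens = tokenizer.tokenize_with_weights("a photo of a cat")
#   t5 = T5XXLModel(t5_config, dtype=torch.float16)      # t5_config is an assumed config dict
#   ctx, _ = t5.encode_token_weights(tokens["t5xxl"])    # (1, 77, 4096) T5-XXL context embeddings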