# other_impls.py

### This file contains impls for underlying related models (CLIP, T5, etc)
import torch, math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
#################################################################################################
### Core/Utility
#################################################################################################
def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
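# Illustrative sketch (not part of the reference implementation): the helper above expects
# q/k/v packed as (batch, seq, heads * dim_head) and returns the same layout. The name
# `_attention_shape_example` is hypothetical, added only to document the expected shapes.
def _attention_shape_example():
    q = k = v = torch.randn(2, 77, 8 * 64)  # batch=2, seq=77, 8 heads of width 64
    out = attention(q, k, v, heads=8)       # -> torch.Size([2, 77, 512])
    return out.shape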
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        self.act = act_layer
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
#################################################################################################
### CLIP
#################################################################################################
class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = attention(q, k, v, self.heads, mask)
        return self.out_proj(out)
ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}
class CLIPLayer(torch.nn.Module):
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
        self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
    def forward(self, x, mask=None):
        x += self.self_attn(self.layer_norm1(x), mask)
        x += self.mlp(self.layer_norm2(x))
        return x
class CLIPEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])
    def forward(self, x, mask=None, intermediate_output=None):
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, l in enumerate(self.layers):
            x = l(x, mask)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
    def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)
        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype
    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding
    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings
    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])
class SDTokenizer:
    def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        empty = self.tokenizer('')["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8
    def tokenize_with_weights(self, text: str):
        """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves have only a weak effect on SD3."""
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0))
        to_tokenize = text.replace("\n", " ").split(' ')
        to_tokenize = [x for x in to_tokenize if x != ""]
        for word in to_tokenize:
            batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
        batch.append((self.end_token, 1.0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
        return [batch]
class SDXLClipGTokenizer(SDTokenizer):
    def __init__(self, tokenizer):
        super().__init__(pad_with_end=False, tokenizer=tokenizer)
class SD3Tokenizer:
    def __init__(self):
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        self.t5xxl = T5XXLTokenizer()
    def tokenize_with_weights(self, text: str):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text)
        out["l"] = self.clip_l.tokenize_with_weights(text)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
        return out
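# Illustrative sketch (not part of the reference implementation): tokenize_with_weights
# returns, per text encoder, a list containing one batch of (token_id, weight) pairs. The
# function name `_sd3_tokenizer_example` and the prompt are hypothetical; running it requires
# the HuggingFace tokenizer assets referenced above to be downloadable.
def _sd3_tokenizer_example():
    tokenizer = SD3Tokenizer()
    out = tokenizer.tokenize_with_weights("a photo of a cat")
    # out["l"] / out["g"]: [[(start, 1.0), ..., (end, 1.0), (pad, 1.0), ...]] padded to 77 entries
    # out["t5xxl"]: [[(token, 1.0), ..., (end, 1.0), (pad, 1.0), ...]] padded to at least 77 entries
    return {k: len(v[0]) for k, v in out.items()}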
class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
        out, pooled = self([tokens])
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = ["last", "pooled", "hidden"]
    def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx
    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = torch.LongTensor(tokens).to(device)
        outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
        self.transformer.set_input_embeddings(backup_embeds)
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        return z.float(), pooled_output
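# Hypothetical toy sketch (not part of the reference implementation): wiring a miniature
# SDClipModel through encode_token_weights just to show the data flow. The config values and
# the function name `_toy_clip_encode_sketch` are made up for illustration, not real CLIP-L
# sizes; torch.no_grad() is needed because CLIPTextModel copies an identity matrix into a
# parameter during construction.
def _toy_clip_encode_sketch():
    toy_config = {
        "num_hidden_layers": 2,
        "hidden_size": 64,
        "num_attention_heads": 4,
        "intermediate_size": 128,
        "hidden_act": "quick_gelu",
    }
    with torch.no_grad():
        model = SDClipModel(layer="last", textmodel_json_config=toy_config, dtype=torch.float32)
        # 77 (token_id, weight) pairs mimicking SDTokenizer output: start, end, then end-token padding
        pairs = [[(49406, 1.0), (49407, 1.0)] + [(49407, 1.0)] * 75]
        embeddings, pooled = model.encode_token_weights(pairs)
    return embeddings.shape, pooled.shape  # -> torch.Size([1, 77, 64]), torch.Size([1, 64])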
class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""
    def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
    def __init__(self):
        super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps
    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x
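# Side note (added sketch, not in the upstream T5 code): T5LayerNorm is an RMSNorm - it scales
# by the root-mean-square of the features without subtracting the mean and has no bias, unlike
# nn.LayerNorm. A minimal check, with the hypothetical name `_t5_layernorm_is_rmsnorm`:
def _t5_layernorm_is_rmsnorm():
    x = torch.randn(2, 5, 16)
    ln = T5LayerNorm(16)
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    return torch.allclose(ln(x), manual)  # weight is all ones at init, so this holds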
class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
    def forward(self, x):
        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x
class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x
class T5Attention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
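    # Illustrative sketch (not part of the upstream T5 code), added as a hypothetical helper:
    # the bucketing above keeps one bucket per exact offset for |relative_position| < 8 and
    # spreads larger offsets over logarithmically sized buckets, with positive offsets shifted
    # into the top half of the bucket range when bidirectional.
    @staticmethod
    def _relative_bucket_example():
        offsets = torch.tensor([-50, -8, -3, 0, 3, 8, 50])
        # small negative offsets map to their absolute value, small positive offsets to value + 16,
        # and the +/-50 offsets fall into the logarithmic range
        return T5Attention._relative_position_bucket(offsets, bidirectional=True, num_buckets=32, max_distance=128)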
    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
    def forward(self, x, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        mask = past_bias if past_bias is not None else None
        # T5 does not scale the dot product by 1/sqrt(d_head); pre-scaling k by sqrt(d_head)
        # cancels the scaling applied inside scaled_dot_product_attention.
        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
        return self.o(out), past_bias
class T5LayerSelfAttention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
    def forward(self, x, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
        x += output
        return x, past_bias
class T5Block(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
    def forward(self, x, past_bias=None):
        x, past_bias = self.layer[0](x, past_bias)
        x = self.layer[-1](x)
        return x, past_bias
class T5Stack(torch.nn.Module):
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
        intermediate = None
        x = self.embed_tokens(input_ids)
        past_bias = None
        for i, l in enumerate(self.block):
            x, past_bias = l(x, past_bias)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate
class T5(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
        self.dtype = dtype
    def get_input_embeddings(self):
        return self.encoder.embed_tokens
    def set_input_embeddings(self, embeddings):
        self.encoder.embed_tokens = embeddings
    def forward(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)
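# Hypothetical usage sketch (not part of the reference implementation): wiring a toy T5 through
# the T5XXLModel wrapper, analogous to the CLIP sketch above. The function name
# `_toy_t5_encode_sketch` and the config values are made up and far smaller than the real T5-XXL.
def _toy_t5_encode_sketch():
    toy_t5_config = {"num_layers": 2, "d_model": 64, "d_ff": 128, "num_heads": 4, "vocab_size": 32128}
    with torch.no_grad():
        t5 = T5XXLModel(toy_t5_config, dtype=torch.float32)
        pairs = [[(t, 1.0) for t in range(10)] + [(1, 1.0)] + [(0, 1.0)] * 66]  # 77 (id, weight) pairs
        embeddings, pooled = t5.encode_token_weights(pairs)
    return embeddings.shape, pooled  # -> torch.Size([1, 77, 64]), None (T5 exposes no pooled output)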