mmdit.py

### This file contains impls for MM-DiT, the core model component of SD3

import math
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat

from modules.models.sd3.other_impls import attention, Mlp

class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding"""
    def __init__(
        self,
        img_size: Optional[int] = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        flatten: bool = True,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        if img_size is not None:
            self.img_size = (img_size, img_size)
            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        # flatten spatial dim and transpose to channels last, kept for bwd compat
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x
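
# `modulate` applies adaLN-style conditioning: per-sample `shift` and `scale`
# vectors of shape (B, D) are broadcast over the sequence dimension of an
# input of shape (B, L, D). A None shift (scale_mod_only mode) is treated as zero.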

def modulate(x, shift, scale):
    if shift is None:
        shift = torch.zeros_like(scale)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

#################################################################################
#                  Sine/Cosine Positional Embedding Functions                   #
#################################################################################
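
# Helpers that build a fixed (non-learned) 2D sine/cosine position embedding.
# Shape sketch, following directly from the code below:
#   get_2d_sincos_pos_embed(embed_dim=8, grid_size=2) -> ndarray of shape
#   (grid_size * grid_size, embed_dim) = (4, 8)
# Half of the channels encode one grid axis and half the other.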

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)

#################################################################################
#                Embedding Layers for Timesteps and Class Labels                #
#################################################################################
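
# TimestepEmbedder maps a (N,) vector of (possibly fractional) diffusion
# timesteps to (N, hidden_size): a sinusoidal frequency embedding (256-dim by
# default) followed by a two-layer SiLU MLP.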

class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations."""
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if torch.is_floating_point(t):
            embedding = embedding.to(dtype=t.dtype)
        return embedding

    def forward(self, t, dtype, **kwargs):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb

class VectorEmbedder(nn.Module):
    """Embeds a flat vector of dimension input_dim"""
    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)

#################################################################################
#                                Core DiT Model                                 #
#################################################################################


class QkvLinear(torch.nn.Linear):
    pass
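
# split_qkv: the fused QKV projection produces (B, L, 3 * dim); reshaping to
# (B, L, 3, num_heads, head_dim) and moving the 3-axis to the front yields
# q, k, v, each of shape (B, L, num_heads, head_dim).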

def split_qkv(qkv, head_dim):
    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
    return qkv[0], qkv[1], qkv[2]


def optimized_attention(qkv, num_heads):
    return attention(qkv[0], qkv[1], qkv[2], num_heads)
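
# SelfAttention exposes pre_attention/post_attention separately (in addition to
# a plain forward) so that JointBlock/block_mixing can run a single attention
# pass over the concatenated context and image token streams. With
# pre_only=True the output projection is omitted entirely (used for the final,
# context-discarding block).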

class SelfAttention(nn.Module):
    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_mode: str = "xformers",
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
        rmsnorm: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        if not pre_only:
            self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode
        self.pre_only = pre_only

        if qk_norm == "rms":
            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm == "ln":
            self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm is None:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()
        else:
            raise ValueError(qk_norm)

    def pre_attention(self, x: torch.Tensor):
        B, L, C = x.shape
        qkv = self.qkv(x)
        q, k, v = split_qkv(qkv, self.head_dim)
        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        x = self.proj(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (q, k, v) = self.pre_attention(x)
        x = attention(q, k, v, self.num_heads)
        x = self.post_attention(x)
        return x
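
# RMSNorm: y = x / sqrt(mean(x^2, dim=-1) + eps), optionally scaled by a
# learnable per-channel weight when elementwise_affine=True (no mean
# subtraction and no bias, unlike LayerNorm).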

class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = self._norm(x)
        if self.learnable_scale:
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        else:
            return x
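
# SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)). The hidden width is first
# scaled to 2/3 of the requested hidden_dim and then rounded up to a multiple
# of `multiple_of`; e.g. dim=1536 with hidden_dim=4*1536=6144 gives
# int(2*6144/3)=4096, which is already a multiple of 256.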

class SwiGLUFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        """
        Initialize the FeedForward module.
        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
        Attributes:
            w1 (nn.Linear): Linear transformation for the first layer.
            w2 (nn.Linear): Linear transformation for the second layer.
            w3 (nn.Linear): Linear transformation for the third layer.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
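
# DismantledBlock is a single DiT block whose attention and MLP halves can be
# driven separately (pre_attention / post_attention). The adaLN_modulation MLP
# produces, per sample, 6 * hidden_size values in the full case (shift, scale,
# and gate for the attention path and for the MLP path), and fewer when
# pre_only or scale_mod_only is set.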

class DismantledBlock(nn.Module):
    """A DiT block with gated adaptive layer norm (adaLN) conditioning."""
    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        dtype=None,
        device=None,
        **block_kwargs,
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if not rmsnorm:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        else:
            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)
        if not pre_only:
            if not rmsnorm:
                self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
            else:
                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if not pre_only:
            if not swiglu:
                self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)
            else:
                self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
        self.scale_mod_only = scale_mod_only
        if not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
        assert x is not None, "pre_attention called with None input"
        if not self.pre_only:
            if not self.scale_mod_only:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
            else:
                shift_msa = None
                shift_mlp = None
                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        else:
            if not self.scale_mod_only:
                shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
            else:
                shift_msa = None
                scale_msa = self.adaLN_modulation(c)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, None

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        (q, k, v), intermediates = self.pre_attention(x, c)
        attn = attention(q, k, v, self.attn.num_heads)
        return self.post_attention(attn, *intermediates)
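
# block_mixing implements the MM-DiT joint attention: context and image tokens
# are projected by their own DismantledBlocks, their q/k/v are concatenated
# along the sequence dimension, a single attention pass is run over the
# combined sequence, and the result is split back so each stream keeps its own
# residual path. If the context block is pre_only (last block), the context
# stream is dropped.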

def block_mixing(context, x, context_block, x_block, c):
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    x_qkv, x_intermediates = x_block.pre_attention(x, c)

    o = []
    for t in range(3):
        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
    q, k, v = tuple(o)

    attn = attention(q, k, v, x_block.attn.num_heads)
    context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :])

    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x

class JointBlock(nn.Module):
    """Just a small wrapper to serve as an FSDP unit."""
    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_only = kwargs.pop("pre_only")
        qk_norm = kwargs.pop("qk_norm", None)
        self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)

    def forward(self, *args, **kwargs):
        return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)

class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = (
            nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
            if (total_out_channels is None)
            else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
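
# MMDiT ties the pieces together. Width is derived from depth: hidden_size is
# 64 * depth and num_heads equals depth, so the per-head size is always 64
# (e.g. depth=24 gives hidden_size=1536 with 24 heads).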

class MMDiT(nn.Module):
    """Diffusion model with a Transformer backbone."""
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_config: Optional[Dict] = None,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        qkv_bias: bool = True,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size

        # apply magic --> this defines a head_size of 64
        hidden_size = 64 * depth
        num_heads = depth
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)

        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)

        self.context_embedder = nn.Identity()
        if context_embedder_config is not None:
            if context_embedder_config["target"] == "torch.nn.Linear":
                self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))

        # num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
            )
        else:
            self.pos_embed = None

        self.joint_blocks = nn.ModuleList(
            [
                JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)
                for i in range(depth)
            ]
        )
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
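
    # cropped_pos_embed center-crops the (pos_embed_max_size x pos_embed_max_size)
    # positional-embedding grid down to the current patched resolution, e.g. a
    # 128x128 latent with patch_size=2 uses a 64x64 crop taken from the middle
    # of the stored grid.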

    def cropped_pos_embed(self, hw):
        assert self.pos_embed_max_size is not None
        p = self.x_embedder.patch_size[0]
        h, w = hw
        # patched size
        h = h // p
        w = w // p
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
        top = (self.pos_embed_max_size - h) // 2
        left = (self.pos_embed_max_size - w) // 2
        spatial_pos_embed = rearrange(
            self.pos_embed,
            "1 (h w) c -> 1 h w c",
            h=self.pos_embed_max_size,
            w=self.pos_embed_max_size,
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
        return spatial_pos_embed
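
    # unpatchify inverts PatchEmbed: (N, T, patch_size**2 * C) token outputs are
    # folded back into (N, C, H, W) images, with T = (H/patch_size) * (W/patch_size).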

    def unpatchify(self, x, hw=None):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        if hw is None:
            h = w = int(x.shape[1] ** 0.5)
        else:
            h, w = hw
            h = h // p
            w = w // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
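
    # forward_core_with_concat prepends optional learned "register" tokens to the
    # context sequence, runs every JointBlock over (context, x), and applies the
    # adaLN-conditioned FinalLayer to the image tokens.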

    def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.register_length > 0:
            context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)

        # context is B, L', D
        # x is B, L, D
        for block in self.joint_blocks:
            context, x = block(context, x, c=c_mod)

        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
        return x

    def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of MM-DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, adm_in_channels) tensor of pooled conditioning vectors
        context: (N, L', D') tensor of conditioning token sequences (projected by context_embedder)
        """
        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw)
        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        if y is not None:
            y = self.y_embedder(y)  # (N, D)
            c = c + y  # (N, D)

        context = self.context_embedder(context)

        x = self.forward_core_with_concat(x, c, context)
        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
        return x
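

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The sizes below are
# arbitrary toy values chosen only to exercise the shapes above; real SD3
# checkpoints define their own depth, channel counts, and context width, and
# load a trained pos_embed instead of the zero-initialized buffer.
#
#   model = MMDiT(
#       patch_size=2,
#       in_channels=4,
#       depth=4,                      # hidden_size = 64 * 4 = 256, 4 heads
#       adm_in_channels=32,
#       context_embedder_config={
#           "target": "torch.nn.Linear",
#           "params": {"in_features": 64, "out_features": 256},
#       },
#       pos_embed_max_size=64,
#       num_patches=64 * 64,
#   )
#   x = torch.randn(1, 4, 32, 32)       # latent image -> 16x16 = 256 patches
#   t = torch.randint(0, 1000, (1,))    # diffusion timesteps
#   y = torch.randn(1, 32)              # pooled conditioning vector
#   context = torch.randn(1, 8, 64)     # conditioning token sequence
#   out = model(x, t, y=y, context=context)   # (1, 4, 32, 32)
# ---------------------------------------------------------------------------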