### This file contains impls for MM-DiT, the core model component of SD3

import math
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat

from modules.models.sd3.other_impls import attention, Mlp

class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        img_size: Optional[int] = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        flatten: bool = True,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        if img_size is not None:
            self.img_size = (img_size, img_size)
            self.grid_size = tuple(s // p for s, p in zip(self.img_size, self.patch_size))
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        # flatten spatial dim and transpose to channels last, kept for bwd compat
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x
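
# Illustrative shape check (not part of the original module):
#   pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
#   pe(torch.randn(2, 3, 224, 224)).shape  # -> torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2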

def modulate(x, shift, scale):
    if shift is None:
        shift = torch.zeros_like(scale)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
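
# Shape note (illustrative): modulate() expects x of shape (B, L, D) and shift/scale of
# shape (B, D); unsqueeze(1) broadcasts the per-sample modulation across all L tokens.
#   modulate(torch.randn(2, 196, 768), torch.randn(2, 768), torch.randn(2, 768)).shape
#   # -> torch.Size([2, 196, 768])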

#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
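
# Illustrative shape check (not part of the original module): each grid axis gets half of
# the channels, and each half is split again into sine and cosine terms.
#   get_2d_sincos_pos_embed(embed_dim=768, grid_size=14).shape  # -> (196, 768)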

#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################


class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if torch.is_floating_point(t):
            embedding = embedding.to(dtype=t.dtype)
        return embedding

    def forward(self, t, dtype, **kwargs):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
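
# Illustrative usage (not part of the original module):
#   te = TimestepEmbedder(hidden_size=1536)
#   te(torch.tensor([999.0, 500.0]), dtype=torch.float32).shape  # -> torch.Size([2, 1536])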

class VectorEmbedder(nn.Module):
    """Embeds a flat vector of dimension input_dim"""

    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)

#################################################################################
#                                 Core DiT Model                                #
#################################################################################


def split_qkv(qkv, head_dim):
    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
    return qkv[0], qkv[1], qkv[2]


def optimized_attention(qkv, num_heads):
    return attention(qkv[0], qkv[1], qkv[2], num_heads)
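
# Shape note (illustrative): for dim = num_heads * head_dim, the fused projection yields
# qkv of shape (B, L, 3 * dim); split_qkv rearranges it to (3, B, L, num_heads, head_dim)
# and returns the three per-head tensors.
#   q, k, v = split_qkv(torch.randn(2, 196, 3 * 768), head_dim=64)
#   q.shape  # -> torch.Size([2, 196, 12, 64])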

class SelfAttention(nn.Module):
    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_mode: str = "xformers",
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
        rmsnorm: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        if not pre_only:
            self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode
        self.pre_only = pre_only

        if qk_norm == "rms":
            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm == "ln":
            self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm is None:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()
        else:
            raise ValueError(qk_norm)

    def pre_attention(self, x: torch.Tensor):
        B, L, C = x.shape
        qkv = self.qkv(x)
        q, k, v = split_qkv(qkv, self.head_dim)
        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        x = self.proj(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (q, k, v) = self.pre_attention(x)
        x = attention(q, k, v, self.num_heads)
        x = self.post_attention(x)
        return x
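
# Illustrative usage (a sketch; assumes other_impls.attention accepts the q/k/v layout
# produced by pre_attention, as the rest of this file does):
#   sa = SelfAttention(dim=768, num_heads=12, qkv_bias=True, attn_mode="torch")
#   sa(torch.randn(2, 196, 768)).shape  # -> torch.Size([2, 196, 768])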

class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = self._norm(x)
        if self.learnable_scale:
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        else:
            return x
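
# RMSNorm computes x * rsqrt(mean(x ** 2, dim=-1) + eps), optionally scaled by a learned
# per-channel weight; unlike LayerNorm it does not subtract the mean. Quick check:
#   torch.allclose(RMSNorm(64)(torch.ones(2, 64)), torch.ones(2, 64), atol=1e-3)  # -> True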

class SwiGLUFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        """
        Initialize the FeedForward module.
        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
        Attributes:
            w1 (nn.Linear): Gate projection from dim to hidden_dim.
            w2 (nn.Linear): Output projection from hidden_dim back to dim.
            w3 (nn.Linear): Value projection from dim to hidden_dim.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
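
# Hidden-width arithmetic (illustrative): for dim=1536, hidden_dim=4 * 1536=6144 and
# multiple_of=256, int(2 * 6144 / 3) = 4096, already a multiple of 256, so w1/w3 project
# 1536 -> 4096 and w2 projects 4096 -> 1536.
#   SwiGLUFeedForward(dim=1536, hidden_dim=6144, multiple_of=256).w1.weight.shape
#   # -> torch.Size([4096, 1536])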

class DismantledBlock(nn.Module):
    """A DiT block with gated adaptive layer norm (adaLN) conditioning."""

    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        dtype=None,
        device=None,
        **block_kwargs,
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if not rmsnorm:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        else:
            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)
        if not pre_only:
            if not rmsnorm:
                self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
            else:
                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if not pre_only:
            if not swiglu:
                self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)
            else:
                self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
        self.scale_mod_only = scale_mod_only
        if not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
        assert x is not None, "pre_attention called with None input"
        if not self.pre_only:
            if not self.scale_mod_only:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
            else:
                shift_msa = None
                shift_mlp = None
                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        else:
            if not self.scale_mod_only:
                shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
            else:
                shift_msa = None
                scale_msa = self.adaLN_modulation(c)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, None

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        (q, k, v), intermediates = self.pre_attention(x, c)
        attn = attention(q, k, v, self.attn.num_heads)
        return self.post_attention(attn, *intermediates)
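
# adaLN chunking (illustrative; exercises the attention/Mlp helpers imported above):
# with pre_only=False and scale_mod_only=False, the conditioning vector c of shape
# (B, hidden_size) is projected to (B, 6 * hidden_size) and split into shift/scale/gate
# for the attention branch plus shift/scale/gate for the MLP branch.
#   blk = DismantledBlock(hidden_size=256, num_heads=4, attn_mode="torch", qkv_bias=True)
#   blk(torch.randn(2, 16, 256), torch.randn(2, 256)).shape  # -> torch.Size([2, 16, 256])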

def block_mixing(context, x, context_block, x_block, c):
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    x_qkv, x_intermediates = x_block.pre_attention(x, c)

    o = []
    for t in range(3):
        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
    q, k, v = tuple(o)

    attn = attention(q, k, v, x_block.attn.num_heads)
    context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1]:])

    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x
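
# Joint attention (illustrative): context tokens and image tokens are concatenated along
# the sequence dimension, attended jointly, then split back to their original lengths.
#   ctx_blk = DismantledBlock(hidden_size=256, num_heads=4, attn_mode="torch", qkv_bias=True)
#   img_blk = DismantledBlock(hidden_size=256, num_heads=4, attn_mode="torch", qkv_bias=True)
#   ctx, img = block_mixing(torch.randn(2, 77, 256), torch.randn(2, 16, 256), ctx_blk, img_blk, c=torch.randn(2, 256))
#   ctx.shape, img.shape  # -> (torch.Size([2, 77, 256]), torch.Size([2, 16, 256]))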

class JointBlock(nn.Module):
    """A small wrapper around a context block and an image block, serving as an FSDP unit."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_only = kwargs.pop("pre_only")
        qk_norm = kwargs.pop("qk_norm", None)
        self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)

    def forward(self, *args, **kwargs):
        return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)

class FinalLayer(nn.Module):
    """The final layer of DiT."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = (
            nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
            if (total_out_channels is None)
            else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
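
# Output width (illustrative): each token is projected to patch_size**2 * out_channels
# values, which unpatchify() later folds back into a spatial latent.
#   FinalLayer(hidden_size=1536, patch_size=2, out_channels=16).linear.out_features  # -> 64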

class MMDiT(nn.Module):
    """Diffusion model with a Transformer backbone."""

    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_config: Optional[Dict] = None,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        qkv_bias: bool = True,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size

        # apply magic --> this defines a head_size of 64
        hidden_size = 64 * depth
        num_heads = depth
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)

        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)

        self.context_embedder = nn.Identity()
        if context_embedder_config is not None:
            if context_embedder_config["target"] == "torch.nn.Linear":
                self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))

        # num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
            )
        else:
            self.pos_embed = None

        self.joint_blocks = nn.ModuleList(
            [
                JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)
                for i in range(depth)
            ]
        )

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
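
    # Sizing note (illustrative): depth doubles as the head count, so every configuration
    # has a head dimension of 64; e.g. depth=24 gives hidden_size=1536 with num_heads=24,
    # and depth=38 gives hidden_size=2432 with num_heads=38.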

    def cropped_pos_embed(self, hw):
        assert self.pos_embed_max_size is not None
        p = self.x_embedder.patch_size[0]
        h, w = hw
        # patched size
        h = h // p
        w = w // p
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
        top = (self.pos_embed_max_size - h) // 2
        left = (self.pos_embed_max_size - w) // 2
        spatial_pos_embed = rearrange(
            self.pos_embed,
            "1 (h w) c -> 1 h w c",
            h=self.pos_embed_max_size,
            w=self.pos_embed_max_size,
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
        return spatial_pos_embed
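
    # Cropping example (illustrative): with pos_embed_max_size=192, patch_size=2 and a
    # 128x128 latent, h = w = 64, so a centred 64x64 window of the 192x192 positional
    # grid is taken and flattened back to shape (1, 4096, hidden_size).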

    def unpatchify(self, x, hw=None):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        if hw is None:
            h = w = int(x.shape[1] ** 0.5)
        else:
            h, w = hw
            h = h // p
            w = w // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
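
    # Shape example (illustrative): x of shape (N, 4096, 2 * 2 * 16) with hw=(128, 128)
    # and patch_size=2 is folded back into imgs of shape (N, 16, 128, 128).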

    def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.register_length > 0:
            context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)

        # context is B, L', D
        # x is B, L, D
        for block in self.joint_blocks:
            context, x = block(context, x, c=c_mod)

        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
        return x

    def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, adm_in_channels) tensor of pooled conditioning vectors
        """
        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw)
        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        if y is not None:
            y = self.y_embedder(y)  # (N, D)
            c = c + y  # (N, D)

        context = self.context_embedder(context)

        x = self.forward_core_with_concat(x, c, context)

        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
        return x
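
# Illustrative end-to-end sketch (made-up sizes, not a shipped SD3 configuration; assumes
# other_impls.attention is available, as the rest of this file does):
#   model = MMDiT(input_size=32, patch_size=2, in_channels=4, depth=4, adm_in_channels=8,
#                 num_patches=256, pos_embed_max_size=16, attn_mode="torch")
#   x = torch.randn(1, 4, 32, 32)          # latent input
#   t = torch.tensor([999.0])              # diffusion timestep
#   y = torch.randn(1, 8)                  # pooled conditioning vector
#   context = torch.randn(1, 77, 4 * 64)   # sequence conditioning, dim == hidden_size == 256
#   model(x, t, y, context).shape          # -> torch.Size([1, 4, 32, 32])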