sd3_impls.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. ### Impls of the SD3 core diffusion model and VAE
  2. import torch, math, einops
  3. from mmdit import MMDiT
  4. from PIL import Image
  5. #################################################################################################
  6. ### MMDiT Model Wrapping
  7. #################################################################################################
  8. class ModelSamplingDiscreteFlow(torch.nn.Module):
  9. """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
  10. def __init__(self, shift=1.0):
  11. super().__init__()
  12. self.shift = shift
  13. timesteps = 1000
  14. ts = self.sigma(torch.arange(1, timesteps + 1, 1))
  15. self.register_buffer('sigmas', ts)
  16. @property
  17. def sigma_min(self):
  18. return self.sigmas[0]
  19. @property
  20. def sigma_max(self):
  21. return self.sigmas[-1]
  22. def timestep(self, sigma):
  23. return sigma * 1000
  24. def sigma(self, timestep: torch.Tensor):
  25. timestep = timestep / 1000.0
  26. if self.shift == 1.0:
  27. return timestep
  28. return self.shift * timestep / (1 + (self.shift - 1) * timestep)
  29. def calculate_denoised(self, sigma, model_output, model_input):
  30. sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
  31. return model_input - model_output * sigma
  32. def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
  33. return sigma * noise + (1.0 - sigma) * latent_image
  34. class BaseModel(torch.nn.Module):
  35. """Wrapper around the core MM-DiT model"""
  36. def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
  37. super().__init__()
  38. # Important configuration values can be quickly determined by checking shapes in the source file
  39. # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
  40. patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
  41. depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
  42. num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
  43. pos_embed_max_size = round(math.sqrt(num_patches))
  44. adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
  45. context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
  46. context_embedder_config = {
  47. "target": "torch.nn.Linear",
  48. "params": {
  49. "in_features": context_shape[1],
  50. "out_features": context_shape[0]
  51. }
  52. }
  53. self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
  54. self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
  55. def apply_model(self, x, sigma, c_crossattn=None, y=None):
  56. dtype = self.get_dtype()
  57. timestep = self.model_sampling.timestep(sigma).float()
  58. model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
  59. return self.model_sampling.calculate_denoised(sigma, model_output, x)
  60. def forward(self, *args, **kwargs):
  61. return self.apply_model(*args, **kwargs)
  62. def get_dtype(self):
  63. return self.diffusion_model.dtype
  64. class CFGDenoiser(torch.nn.Module):
  65. """Helper for applying CFG Scaling to diffusion outputs"""
  66. def __init__(self, model):
  67. super().__init__()
  68. self.model = model
  69. def forward(self, x, timestep, cond, uncond, cond_scale):
  70. # Run cond and uncond in a batch together
  71. batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]]))
  72. # Then split and apply CFG Scaling
  73. pos_out, neg_out = batched.chunk(2)
  74. scaled = neg_out + (pos_out - neg_out) * cond_scale
  75. return scaled
  76. class SD3LatentFormat:
  77. """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
  78. def __init__(self):
  79. self.scale_factor = 1.5305
  80. self.shift_factor = 0.0609
  81. def process_in(self, latent):
  82. return (latent - self.shift_factor) * self.scale_factor
  83. def process_out(self, latent):
  84. return (latent / self.scale_factor) + self.shift_factor
  85. def decode_latent_to_preview(self, x0):
  86. """Quick RGB approximate preview of sd3 latents"""
  87. factors = torch.tensor([
  88. [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
  89. [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
  90. [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
  91. [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
  92. [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
  93. [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
  94. [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
  95. [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259]
  96. ], device="cpu")
  97. latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
  98. latents_ubyte = (((latent_image + 1) / 2)
  99. .clamp(0, 1) # change scale from -1..1 to 0..1
  100. .mul(0xFF) # to 0..255
  101. .byte()).cpu()
  102. return Image.fromarray(latents_ubyte.numpy())
  103. #################################################################################################
  104. ### K-Diffusion Sampling
  105. #################################################################################################
  106. def append_dims(x, target_dims):
  107. """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
  108. dims_to_append = target_dims - x.ndim
  109. return x[(...,) + (None,) * dims_to_append]
  110. def to_d(x, sigma, denoised):
  111. """Converts a denoiser output to a Karras ODE derivative."""
  112. return (x - denoised) / append_dims(sigma, x.ndim)
  113. @torch.no_grad()
  114. @torch.autocast("cuda", dtype=torch.float16)
  115. def sample_euler(model, x, sigmas, extra_args=None):
  116. """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
  117. extra_args = {} if extra_args is None else extra_args
  118. s_in = x.new_ones([x.shape[0]])
  119. for i in range(len(sigmas) - 1):
  120. sigma_hat = sigmas[i]
  121. denoised = model(x, sigma_hat * s_in, **extra_args)
  122. d = to_d(x, sigma_hat, denoised)
  123. dt = sigmas[i + 1] - sigma_hat
  124. # Euler method
  125. x = x + d * dt
  126. return x
  127. #################################################################################################
  128. ### VAE
  129. #################################################################################################
  130. def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
  131. return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
  132. class ResnetBlock(torch.nn.Module):
  133. def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
  134. super().__init__()
  135. self.in_channels = in_channels
  136. out_channels = in_channels if out_channels is None else out_channels
  137. self.out_channels = out_channels
  138. self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
  139. self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
  140. self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
  141. self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
  142. if self.in_channels != self.out_channels:
  143. self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
  144. else:
  145. self.nin_shortcut = None
  146. self.swish = torch.nn.SiLU(inplace=True)
  147. def forward(self, x):
  148. hidden = x
  149. hidden = self.norm1(hidden)
  150. hidden = self.swish(hidden)
  151. hidden = self.conv1(hidden)
  152. hidden = self.norm2(hidden)
  153. hidden = self.swish(hidden)
  154. hidden = self.conv2(hidden)
  155. if self.in_channels != self.out_channels:
  156. x = self.nin_shortcut(x)
  157. return x + hidden
  158. class AttnBlock(torch.nn.Module):
  159. def __init__(self, in_channels, dtype=torch.float32, device=None):
  160. super().__init__()
  161. self.norm = Normalize(in_channels, dtype=dtype, device=device)
  162. self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
  163. self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
  164. self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
  165. self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
  166. def forward(self, x):
  167. hidden = self.norm(x)
  168. q = self.q(hidden)
  169. k = self.k(hidden)
  170. v = self.v(hidden)
  171. b, c, h, w = q.shape
  172. q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
  173. hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
  174. hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
  175. hidden = self.proj_out(hidden)
  176. return x + hidden
  177. class Downsample(torch.nn.Module):
  178. def __init__(self, in_channels, dtype=torch.float32, device=None):
  179. super().__init__()
  180. self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
  181. def forward(self, x):
  182. pad = (0,1,0,1)
  183. x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
  184. x = self.conv(x)
  185. return x
  186. class Upsample(torch.nn.Module):
  187. def __init__(self, in_channels, dtype=torch.float32, device=None):
  188. super().__init__()
  189. self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
  190. def forward(self, x):
  191. x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
  192. x = self.conv(x)
  193. return x
  194. class VAEEncoder(torch.nn.Module):
  195. def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):
  196. super().__init__()
  197. self.num_resolutions = len(ch_mult)
  198. self.num_res_blocks = num_res_blocks
  199. # downsampling
  200. self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
  201. in_ch_mult = (1,) + tuple(ch_mult)
  202. self.in_ch_mult = in_ch_mult
  203. self.down = torch.nn.ModuleList()
  204. for i_level in range(self.num_resolutions):
  205. block = torch.nn.ModuleList()
  206. attn = torch.nn.ModuleList()
  207. block_in = ch*in_ch_mult[i_level]
  208. block_out = ch*ch_mult[i_level]
  209. for i_block in range(num_res_blocks):
  210. block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
  211. block_in = block_out
  212. down = torch.nn.Module()
  213. down.block = block
  214. down.attn = attn
  215. if i_level != self.num_resolutions - 1:
  216. down.downsample = Downsample(block_in, dtype=dtype, device=device)
  217. self.down.append(down)
  218. # middle
  219. self.mid = torch.nn.Module()
  220. self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
  221. self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
  222. self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
  223. # end
  224. self.norm_out = Normalize(block_in, dtype=dtype, device=device)
  225. self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
  226. self.swish = torch.nn.SiLU(inplace=True)
  227. def forward(self, x):
  228. # downsampling
  229. hs = [self.conv_in(x)]
  230. for i_level in range(self.num_resolutions):
  231. for i_block in range(self.num_res_blocks):
  232. h = self.down[i_level].block[i_block](hs[-1])
  233. hs.append(h)
  234. if i_level != self.num_resolutions-1:
  235. hs.append(self.down[i_level].downsample(hs[-1]))
  236. # middle
  237. h = hs[-1]
  238. h = self.mid.block_1(h)
  239. h = self.mid.attn_1(h)
  240. h = self.mid.block_2(h)
  241. # end
  242. h = self.norm_out(h)
  243. h = self.swish(h)
  244. h = self.conv_out(h)
  245. return h
  246. class VAEDecoder(torch.nn.Module):
  247. def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):
  248. super().__init__()
  249. self.num_resolutions = len(ch_mult)
  250. self.num_res_blocks = num_res_blocks
  251. block_in = ch * ch_mult[self.num_resolutions - 1]
  252. curr_res = resolution // 2 ** (self.num_resolutions - 1)
  253. # z to block_in
  254. self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
  255. # middle
  256. self.mid = torch.nn.Module()
  257. self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
  258. self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
  259. self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
  260. # upsampling
  261. self.up = torch.nn.ModuleList()
  262. for i_level in reversed(range(self.num_resolutions)):
  263. block = torch.nn.ModuleList()
  264. block_out = ch * ch_mult[i_level]
  265. for i_block in range(self.num_res_blocks + 1):
  266. block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
  267. block_in = block_out
  268. up = torch.nn.Module()
  269. up.block = block
  270. if i_level != 0:
  271. up.upsample = Upsample(block_in, dtype=dtype, device=device)
  272. curr_res = curr_res * 2
  273. self.up.insert(0, up) # prepend to get consistent order
  274. # end
  275. self.norm_out = Normalize(block_in, dtype=dtype, device=device)
  276. self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
  277. self.swish = torch.nn.SiLU(inplace=True)
  278. def forward(self, z):
  279. # z to block_in
  280. hidden = self.conv_in(z)
  281. # middle
  282. hidden = self.mid.block_1(hidden)
  283. hidden = self.mid.attn_1(hidden)
  284. hidden = self.mid.block_2(hidden)
  285. # upsampling
  286. for i_level in reversed(range(self.num_resolutions)):
  287. for i_block in range(self.num_res_blocks + 1):
  288. hidden = self.up[i_level].block[i_block](hidden)
  289. if i_level != 0:
  290. hidden = self.up[i_level].upsample(hidden)
  291. # end
  292. hidden = self.norm_out(hidden)
  293. hidden = self.swish(hidden)
  294. hidden = self.conv_out(hidden)
  295. return hidden
  296. class SDVAE(torch.nn.Module):
  297. def __init__(self, dtype=torch.float32, device=None):
  298. super().__init__()
  299. self.encoder = VAEEncoder(dtype=dtype, device=device)
  300. self.decoder = VAEDecoder(dtype=dtype, device=device)
  301. @torch.autocast("cuda", dtype=torch.float16)
  302. def decode(self, latent):
  303. return self.decoder(latent)
  304. @torch.autocast("cuda", dtype=torch.float16)
  305. def encode(self, image):
  306. hidden = self.encoder(image)
  307. mean, logvar = torch.chunk(hidden, 2, dim=1)
  308. logvar = torch.clamp(logvar, -30.0, 20.0)
  309. std = torch.exp(0.5 * logvar)
  310. return mean + std * torch.randn_like(mean)