sd3_model.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import contextlib
  2. import os
  3. from typing import Mapping
  4. import safetensors
  5. import torch
  6. import k_diffusion
  7. from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
  8. from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
  9. from modules import shared, modelloader, devices
  # Download locations and architecture settings for the three text encoders
  # that SD3Cond (below) builds and later fills with weights.

  # CLIP-G: URL used by SD3Cond.load_weights, config passed to SDXLClipG.
  CLIPG_URL: str = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
  CLIPG_CONFIG: dict[str, object] = {
      "hidden_act": "gelu",
      "hidden_size": 1280,
      "intermediate_size": 5120,
      "num_attention_heads": 20,
      "num_hidden_layers": 32,
  }

  # CLIP-L: URL used by SD3Cond.load_weights, config passed to SDClipModel
  # as ``textmodel_json_config``.
  CLIPL_URL: str = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
  CLIPL_CONFIG: dict[str, object] = {
      "hidden_act": "quick_gelu",
      "hidden_size": 768,
      "intermediate_size": 3072,
      "num_attention_heads": 12,
      "num_hidden_layers": 12,
  }

  # T5-XXL: URL used by SD3Cond.load_weights, config passed to T5XXLModel.
  T5_URL: str = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
  T5_CONFIG: dict[str, object] = {
      "d_ff": 10240,
      "d_model": 4096,
      "num_heads": 64,
      "num_layers": 24,
      "vocab_size": 32128,
  }
  class SafetensorsMapping(Mapping):
      """Read-only ``Mapping`` view over an open safetensors file handle.

      The wrapped object only needs ``keys()`` and ``get_tensor(name)``. In this
      file the view is handed to ``load_state_dict`` so that a checkpoint can be
      consumed through the mapping protocol. Nothing is cached here: every
      lookup goes back to the handle.

      Args:
          file: open handle providing ``keys()`` and ``get_tensor(name)``. The
              caller owns it and must keep it open while the mapping is used.
      """

      def __init__(self, file):
          self.file = file

      def __len__(self):
          """Return the number of entries reported by the handle."""
          return len(self.file.keys())

      def __iter__(self):
          """Iterate over entry names in the order the handle reports them."""
          yield from self.file.keys()

      def __contains__(self, key):
          """Return whether *key* is one of the handle's entry names.

          Defined explicitly so that membership tests never have to fetch the
          entry itself (the inherited version calls ``__getitem__``).
          """
          return key in self.file.keys()

      def __getitem__(self, key):
          """Return the entry stored under *key*.

          Raises:
              KeyError: if the handle has no entry named *key*.
          """
          # The Mapping mixins (``get``, ``in``, ``==``) rely on KeyError for
          # missing keys. The handle's own error for an unknown name is not
          # guaranteed to be a KeyError, so absence is reported here.
          if key not in self.file.keys():
              raise KeyError(key)
          return self.file.get_tensor(key)
  class SD3Cond(torch.nn.Module):
      """Text-conditioning module combining CLIP-L, CLIP-G and T5-XXL encoders.

      The encoders are constructed on the CPU without weights; call
      ``load_weights`` before encoding. ``forward`` turns a list of prompts
      into one conditioning dict per prompt.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

          self.tokenizer = SD3Tokenizer()

          # NOTE(review): the extent of this no_grad block was inferred from a
          # flattened source listing - confirm against the repository.
          with torch.no_grad():
              self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
              self.clip_l = SDClipModel(
                  layer="hidden",
                  layer_idx=-2,
                  device="cpu",
                  dtype=devices.dtype,
                  layer_norm_hidden_state=False,
                  return_projected_pooled=False,
                  textmodel_json_config=CLIPL_CONFIG,
              )
              self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)

          # Flipped to True by load_weights once all three checkpoints are in.
          self.weights_loaded = False

      def forward(self, prompts: list[str]):
          """Encode each prompt and return a list of conditioning dicts.

          Each dict has two entries, both moved to ``devices.device``:
          ``'crossattn'`` (CLIP-L and CLIP-G outputs joined on the last axis,
          zero-padded to width 4096, then joined with the T5 output on the
          second-to-last axis) and ``'vector'`` (the two CLIP pooled outputs
          joined on the last axis). Only index 0 of each result is kept.
          """
          conditionings = []

          for text in prompts:
              weighted = self.tokenizer.tokenize_with_weights(text)

              clip_l_hidden, clip_l_pooled = self.clip_l.encode_token_weights(weighted["l"])
              clip_g_hidden, clip_g_pooled = self.clip_g.encode_token_weights(weighted["g"])
              # The pooled T5 output is not used.
              t5_hidden, _ = self.t5xxl.encode_token_weights(weighted["t5xxl"])

              clip_hidden = torch.cat([clip_l_hidden, clip_g_hidden], dim=-1)
              # Right-pad the last axis with zeros up to 4096; presumably this
              # matches the T5 feature width (T5_CONFIG "d_model") - confirm.
              pad_width = 4096 - clip_hidden.shape[-1]
              clip_hidden = torch.nn.functional.pad(clip_hidden, (0, pad_width))

              context = torch.cat([clip_hidden, t5_hidden], dim=-2)
              pooled = torch.cat((clip_l_pooled, clip_g_pooled), dim=-1)

              conditionings.append({
                  'crossattn': context[0].to(devices.device),
                  'vector': pooled[0].to(devices.device),
              })

          return conditionings

      def load_weights(self):
          """Download (if needed) and load the three encoder checkpoints once.

          Files are fetched via ``modelloader.load_file_from_url`` into the
          ``CLIP`` folder under ``shared.models_path``. Later calls return
          immediately.
          """
          if self.weights_loaded:
              return

          clip_dir = os.path.join(shared.models_path, "CLIP")

          # (url, file name, encoder wrapper, extra load_state_dict kwargs),
          # processed in this order. CLIP-G uses load_state_dict's defaults;
          # the other two pass strict=False.
          plan = (
              (CLIPG_URL, "clip_g.safetensors", self.clip_g, {}),
              (CLIPL_URL, "clip_l.safetensors", self.clip_l, {"strict": False}),
              (T5_URL, "t5xxl_fp16.safetensors", self.t5xxl, {"strict": False}),
          )

          for url, file_name, encoder, load_kwargs in plan:
              checkpoint = modelloader.load_file_from_url(url, model_dir=clip_dir, file_name=file_name)
              with safetensors.safe_open(checkpoint, framework="pt") as handle:
                  encoder.transformer.load_state_dict(SafetensorsMapping(handle), **load_kwargs)

          self.weights_loaded = True

      def encode_embedding_init_text(self, init_text, nvpt):
          """Return a fixed ``[[0]]`` tensor on ``devices.device``.

          Both arguments are ignored; the original code marks this as XXX.
          """
          placeholder = torch.tensor([[0]], device=devices.device)  # XXX
          return placeholder
  class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
      """Discrete-schedule wrapper that delegates denoising to ``inner_model``.

      Args:
          inner_model: object exposing ``apply_model(input, sigma, **kwargs)``.
          sigmas: sigma schedule handed to the ``DiscreteSchedule`` base class.
      """

      def __init__(self, inner_model, sigmas):
          # Quantization of sigmas follows the user option.
          quantize = shared.opts.enable_quantization
          super().__init__(sigmas, quantize=quantize)
          self.inner_model = inner_model

      def forward(self, input, sigma, **kwargs):
          """Forward the call unchanged to ``inner_model.apply_model``."""
          wrapped = self.inner_model
          return wrapped.apply_model(input, sigma, **kwargs)
  class SD3Inferencer(torch.nn.Module):
      """Bundles the SD3 diffusion model, VAE and text conditioner.

      Args:
          state_dict: checkpoint weights passed to ``BaseModel``.
          shift: value stored on the instance and passed to ``BaseModel``.
          use_ema: accepted but not used by this constructor.
      """

      def __init__(self, state_dict, shift=3, use_ema=False):
          super().__init__()

          self.shift = shift

          # NOTE(review): the extent of this no_grad block was inferred from a
          # flattened source listing - confirm against the repository.
          with torch.no_grad():
              diffusion = BaseModel(
                  shift=shift,
                  state_dict=state_dict,
                  prefix="model.diffusion_model.",
                  device="cpu",
                  dtype=devices.dtype,
              )
              self.model = diffusion

              vae = SDVAE(device="cpu", dtype=devices.dtype_vae)
              self.first_stage_model = vae
              # The VAE's ``dtype`` attribute mirrors the diffusion model's.
              self.first_stage_model.dtype = self.model.diffusion_model.dtype

          sigmas = self.model.model_sampling.sigmas
          self.alphas_cumprod = 1 / (sigmas ** 2 + 1)

          self.cond_stage_model = SD3Cond()
          self.cond_stage_key = 'txt'

          self.parameterization = "eps"
          self.model.conditioning_key = "crossattn"

          self.latent_format = SD3LatentFormat()
          self.latent_channels = 16

      def after_load_weights(self):
          """Load the text-encoder weights of the conditioning stage."""
          conditioner = self.cond_stage_model
          conditioner.load_weights()

      def ema_scope(self):
          """Return a no-op context manager."""
          return contextlib.nullcontext()

      def get_learned_conditioning(self, batch: list[str]):
          """Run the conditioning stage on a batch of prompts."""
          conditioner = self.cond_stage_model
          return conditioner(batch)

      def apply_model(self, x, t, cond):
          """Call the diffusion model with ``cond['crossattn']`` and ``cond['vector']``."""
          context = cond['crossattn']
          pooled = cond['vector']
          return self.model.apply_model(x, t, c_crossattn=context, y=pooled)

      def decode_first_stage(self, latent):
          """Apply ``latent_format.process_out`` and then decode with the VAE."""
          prepared = self.latent_format.process_out(latent)
          return self.first_stage_model.decode(prepared)

      def encode_first_stage(self, image):
          """Encode with the VAE and then apply ``latent_format.process_in``."""
          encoded = self.first_stage_model.encode(image)
          return self.latent_format.process_in(encoded)

      def create_denoiser(self):
          """Build an ``SD3Denoiser`` around this object using the model's sigmas."""
          sigmas = self.model.model_sampling.sigmas
          return SD3Denoiser(self, sigmas)