# sd3_model.py
  1. import contextlib
  2. import torch
  3. import k_diffusion
  4. from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
  5. from modules.models.sd3.sd3_cond import SD3Cond
  6. from modules import shared, devices
  7. class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
  8. def __init__(self, inner_model, sigmas):
  9. super().__init__(sigmas, quantize=shared.opts.enable_quantization)
  10. self.inner_model = inner_model
  11. def forward(self, input, sigma, **kwargs):
  12. return self.inner_model.apply_model(input, sigma, **kwargs)
  13. class SD3Inferencer(torch.nn.Module):
  14. def __init__(self, state_dict, shift=3, use_ema=False):
  15. super().__init__()
  16. self.shift = shift
  17. with torch.no_grad():
  18. self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
  19. self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
  20. self.first_stage_model.dtype = self.model.diffusion_model.dtype
  21. self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
  22. self.text_encoders = SD3Cond()
  23. self.cond_stage_key = 'txt'
  24. self.parameterization = "eps"
  25. self.model.conditioning_key = "crossattn"
  26. self.latent_format = SD3LatentFormat()
  27. self.latent_channels = 16
  28. @property
  29. def cond_stage_model(self):
  30. return self.text_encoders
  31. def before_load_weights(self, state_dict):
  32. self.cond_stage_model.before_load_weights(state_dict)
  33. def ema_scope(self):
  34. return contextlib.nullcontext()
  35. def get_learned_conditioning(self, batch: list[str]):
  36. return self.cond_stage_model(batch)
  37. def apply_model(self, x, t, cond):
  38. return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
  39. def decode_first_stage(self, latent):
  40. latent = self.latent_format.process_out(latent)
  41. return self.first_stage_model.decode(latent)
  42. def encode_first_stage(self, image):
  43. latent = self.first_stage_model.encode(image)
  44. return self.latent_format.process_in(latent)
  45. def get_first_stage_encoding(self, x):
  46. return x
  47. def create_denoiser(self):
  48. return SD3Denoiser(self, self.model.model_sampling.sigmas)
  49. def medvram_fields(self):
  50. return [
  51. (self, 'first_stage_model'),
  52. (self, 'text_encoders'),
  53. (self, 'model'),
  54. ]
  55. def add_noise_to_latent(self, x, noise, amount):
  56. return x * (1 - amount) + noise * amount
  57. def fix_dimensions(self, width, height):
  58. return width // 16 * 16, height // 16 * 16
  59. def diffusers_weight_mapping(self):
  60. for i in range(self.model.depth):
  61. yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
  62. yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
  63. yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
  64. yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
  65. yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
  66. yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
  67. yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
  68. yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"