sd3_cond.py

import os
import safetensors
import torch
import typing

from transformers import CLIPTokenizer, T5TokenizerFast

from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer


class SafetensorsMapping(typing.Mapping):
    """Exposes an open safetensors file as a read-only Mapping, so load_state_dict can pull tensors from it lazily."""

    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        for key in self.file.keys():
            yield key

    def __getitem__(self, key):
        return self.file.get_tensor(key)
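
# Usage sketch for SafetensorsMapping (mirrors before_load_weights further down; the file path is illustrative):
#   with safetensors.safe_open("clip_g.safetensors", framework="pt") as file:
#       model.transformer.load_state_dict(SafetensorsMapping(file), strict=False)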

CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
    "textual_inversion_key": "clip_g",
}

T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}


class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        empty = self.tokenizer('')["input_ids"]
        self.id_start = empty[0]
        self.id_end = empty[1]
        self.id_pad = empty[1]

        self.return_pooled = True

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        tokens_g = tokens.clone()

        # for the CLIP-G copy of the tokens, zero out everything after the end-of-text token
        for batch_pos in range(tokens_g.shape[0]):
            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(tokens_g)

        # concatenate CLIP-L and CLIP-G features, then zero-pad them up to the T5 feature width
        lg_out = torch.cat([l_out, g_out], dim=-1)
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

        lg_out.pooled = vector_out
        return lg_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 768 + 1280), device=devices.device)  # XXX
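
# Shape sketch for Sd3ClipLG.encode_with_transformers (batch size B and the 77-token length are illustrative):
#   l_out: (B, 77, 768), g_out: (B, 77, 1280) -> cat: (B, 77, 2048) -> pad: (B, 77, 4096)
#   pooled vector: (B, 768 + 1280) = (B, 2048)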


class Sd3T5(torch.nn.Module):
    def __init__(self, t5xxl):
        super().__init__()

        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            if len(tokens) < target_token_count:
                # compute the pad length before extending tokens, so multipliers gets padded to the same length as tokens
                pad_count = target_token_count - len(tokens)
                tokens += [self.id_pad] * pad_count
                multipliers += [1.0] * pad_count
            else:
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers
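
    # Behaviour sketch for tokenize_line (sd3_t5 is an Sd3T5 instance; the prompt and 77-token target are illustrative):
    #   tokens, multipliers = sd3_t5.tokenize_line("a (cat:1.3) on a mat", target_token_count=77)
    #   tokens      -> T5 token ids with id_end appended, then padded with id_pad (or truncated) to 77 entries
    #   multipliers -> one weight per token: 1.3 for the emphasised span, 1.0 elsewhere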

    def forward(self, texts, *, token_count):
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            # T5 is disabled: stand in with zeros shaped like the real T5 output
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        tokens_batch = []

        for text in texts:
            tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
            tokens_batch.append(tokens)

        t5_out, t5_pooled = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX


class SD3Cond(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

            if shared.opts.sd3_enable_t5:
                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
            else:
                self.t5xxl = None

            self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
            self.model_t5 = Sd3T5(self.t5xxl)

    def forward(self, prompts: list[str]):
        with devices.without_autocast():
            lg_out, vector_out = self.model_lg(prompts)
            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
            lgt_out = torch.cat([lg_out, t5_out], dim=-2)

        return {
            'crossattn': lgt_out,
            'vector': vector_out,
        }
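
    # Output sketch for forward (batch of one prompt; N is the CLIP token count):
    #   cond = sd3_cond(["a photo of a cat"])   # sd3_cond: an SD3Cond instance with weights loaded (illustrative)
    #   cond["crossattn"]  # (1, 2*N, 4096): padded CLIP-L/G features with the T5 features appended along the token axis
    #   cond["vector"]     # (1, 2048): pooled CLIP-L (768) concatenated with pooled CLIP-G (1280)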

    def before_load_weights(self, state_dict):
        # if the checkpoint's state_dict lacks any of the text encoders, download their weights and load them separately
        clip_path = os.path.join(shared.models_path, "CLIP")

        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")

            with safetensors.safe_open(clip_g_file, framework="pt") as file:
                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")

            with safetensors.safe_open(clip_l_file, framework="pt") as file:
                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")

            with safetensors.safe_open(t5_file, framework="pt") as file:
                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

    def encode_embedding_init_text(self, init_text, nvpt):
        return self.model_lg.encode_embedding_init_text(init_text, nvpt)

    def tokenize(self, texts):
        return self.model_lg.tokenize(texts)

    def medvram_modules(self):
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        _, token_count = self.model_lg.process_texts([text])
        return token_count

    def get_target_prompt_token_count(self, token_count):
        return self.model_lg.get_target_prompt_token_count(token_count)