sd_hijack.py

import math
import os
import sys
import traceback
import torch
import numpy as np

from torch import einsum
from torch.nn.functional import silu

import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
from modules.shared import opts, device, cmd_opts

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
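
# apply_optimizations/undo_optimizations monkeypatch ldm's attention and nonlinearity
# implementations with the variants from sd_hijack_optimizations; the original forwards
# are saved above so the patches can be reverted.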
def apply_optimizations():
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu

    if cmd_opts.opt_split_attention_v1:
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward


def undo_optimizations():
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
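
# Hooks the model's CLIP text encoder: wraps the token embedding layer and the
# cond_stage_model so prompts support emphasis syntax and textual inversion embeddings,
# and keeps enough state (layers, fixes, comments) to undo the hijack later.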
class StableDiffusionModelHijack:
    fixes = None
    comments = []
    layers = None
    circular_enabled = False
    clip = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

    def hijack(self, m):
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.clip = m.cond_stage_model

        apply_optimizations()

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def tokenize(self, text):
        max_length = self.clip.max_length - 2
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, max_length
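
# Wrapper around the FrozenCLIPEmbedder cond_stage_model. On construction it precomputes
# attention multipliers for vocabulary tokens containing ( ) [ ] characters: each '(' or
# ']' multiplies the weight by 1.1, each ')' or '[' divides it by 1.1.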
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack: StableDiffusionModelHijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
        self.token_mults = {}

        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult
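
    # Tokenizes one prompt line: applies prompt_parser attention weights when emphasis is
    # enabled, replaces spans covered by textual inversion embeddings with placeholder
    # token ids (recorded in `fixes`), and pads/truncates to maxlen with BOS/EOS tokens.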
    def tokenize_line(self, line, used_custom_terms, hijack_comments):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length

        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                if embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(weight)
                    i += 1
                else:
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [weight] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

        if len(remade_tokens) > maxlen - 2:
            vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
            ovf = remade_tokens[maxlen - 2:]
            overflowing_words = [vocab.get(int(x), "") for x in ovf]
            overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
            hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

        token_count = len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
        remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]

        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        return remade_tokens, fixes, multipliers, token_count
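
    # Batch version: runs tokenize_line on each prompt, caching results for duplicate
    # lines within the batch.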
    def process_text(self, texts):
        used_custom_terms = []
        remade_batch_tokens = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)

                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
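
    # Legacy emphasis implementation, used when opts.use_old_emphasis_implementation is
    # set: per-token multipliers come from the bracket tokens themselves (token_mults)
    # instead of prompt_parser.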
    def process_text_old(self, text):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length
        used_custom_terms = []
        remade_batch_tokens = []
        overflowing_words = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                    if mult_change is not None:
                        mult *= mult_change
                        i += 1
                    elif embedding is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                        i += 1
                    else:
                        emb_len = int(embedding.vec.shape[0])
                        fixes.append((len(remade_tokens), embedding))
                        remade_tokens += [0] * emb_len
                        multipliers += [mult] * emb_len
                        used_custom_terms.append((embedding.name, embedding.checksum()))
                        i += embedding_length_in_tokens

                if len(remade_tokens) > maxlen - 2:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                token_count = len(remade_tokens)
                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]

                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
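
    # Encodes the prompt batch: stores fixes/comments on the hijack object for
    # EmbeddingsWithFixes to pick up, runs the CLIP transformer, then scales the hidden
    # states by the per-token multipliers while restoring the original mean.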
    def forward(self, text):
        if opts.use_old_emphasis_implementation:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
        else:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)

        self.hijack.fixes = hijack_fixes
        self.hijack.comments = hijack_comments

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z
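
# Wraps the token embedding layer; wherever the hijack recorded a fix for the current
# batch, the corresponding rows of the embedding output are replaced with the textual
# inversion embedding's vectors.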
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = embedding.vec
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
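
# Monkeypatches torch.nn.Conv2d so that convolutions constructed afterwards use circular
# padding, used for generating tileable images.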
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular


model_hijack = StableDiffusionModelHijack()
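
# Illustrative usage sketch, not part of the original module. Assumes `sd_model` is a
# loaded LDM Stable Diffusion model whose cond_stage_model is a FrozenCLIPEmbedder:
#
#   model_hijack.hijack(sd_model)        # wrap the CLIP embedder and apply optimizations
#   tokens, token_count, max_length = model_hijack.tokenize("a photo of a cat")
#   model_hijack.undo_hijack(sd_model)   # restore the original cond_stage_model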