# sd_hijack.py

import math
import os
import sys
import traceback
import torch
import numpy as np
from torch import einsum
from torch.nn.functional import silu

import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
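

# This module "hijacks" the loaded Stable Diffusion model: it monkey-patches classes
# from the ldm package and wraps the CLIP text encoder so prompts gain emphasis syntax
# and textual-inversion embeddings. The references saved above keep the stock ldm
# implementations so the original behavior can be put back later.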
def apply_optimizations():
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu

    if cmd_opts.opt_split_attention_v1:
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
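

# Restores the patched functions. Note that CrossAttention.forward is reset to the
# hypernetwork module's variant (which applies a loaded hypernetwork, if any) rather
# than to the raw original saved above.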
def undo_optimizations():
    from modules.hypernetwork import hypernetwork

    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
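

# Central object that performs and tracks the hijack of a loaded model: it swaps in
# the wrapper classes defined below, remembers per-batch embedding "fixes" and the
# comments produced while processing prompts, and exposes helpers for circular
# padding and tokenization.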
class StableDiffusionModelHijack:
    fixes = None
    comments = []
    layers = None
    circular_enabled = False
    clip = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
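
    # Wraps the model in place: token_embedding is replaced with EmbeddingsWithFixes
    # (so textual-inversion vectors can be spliced in) and cond_stage_model with
    # FrozenCLIPEmbedderWithCustomWords (emphasis + embedding-aware tokenization).
    # Also applies the attention optimizations and records a flattened list of
    # layers, later used by apply_circular().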
    def hijack(self, m):
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.clip = m.cond_stage_model

        apply_optimizations()

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'
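
    # Tokenizes a single prompt and reports how many tokens it uses against the
    # 75-token budget (max_length minus BOS/EOS); presumably what the UI token
    # counter relies on.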
    def tokenize(self, text):
        max_length = self.clip.max_length - 2
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, max_length
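

# Replacement for the CLIP text encoder (ldm's FrozenCLIPEmbedder). It adds:
#   * emphasis syntax: "(word)" multiplies the token's weight by 1.1, "[word]"
#     divides it by 1.1, so e.g. "((word))" gives roughly 1.21;
#   * textual-inversion embeddings looked up via hijack.embedding_db;
# and it records per-token multipliers that forward() applies to the hidden states.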
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack: StableDiffusionModelHijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
        self.token_mults = {}

        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult
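
    # Converts one prompt line into a fixed-length (maxlen) token sequence plus a
    # parallel list of weight multipliers. Emphasis is parsed by
    # prompt_parser.parse_prompt_attention; wherever a textual-inversion embedding is
    # recognized, placeholder token 0 is emitted and the (offset, embedding) pair is
    # recorded in fixes so EmbeddingsWithFixes can splice the vectors in later.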
    def tokenize_line(self, line, used_custom_terms, hijack_comments):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length

        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                if embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(weight)
                    i += 1
                else:
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [weight] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

        if len(remade_tokens) > maxlen - 2:
            vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
            ovf = remade_tokens[maxlen - 2:]
            overflowing_words = [vocab.get(int(x), "") for x in ovf]
            overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
            hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

        token_count = len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
        remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]

        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        return remade_tokens, fixes, multipliers, token_count
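
    # Batch version of tokenize_line with a per-prompt cache; returns everything
    # forward() needs, including the token count of the last uncached line.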
    def process_text(self, texts):
        used_custom_terms = []
        remade_batch_tokens = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)

                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
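
    # Legacy implementation kept for opts.use_old_emphasis_implementation: emphasis
    # is detected from parenthesis/bracket tokens in the tokenized text (via
    # token_mults) instead of prompt_parser, reproducing the behavior of older
    # versions of this code.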
    def process_text_old(self, text):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length
        used_custom_terms = []
        remade_batch_tokens = []
        overflowing_words = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                    if mult_change is not None:
                        mult *= mult_change
                        i += 1
                    elif embedding is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                        i += 1
                    else:
                        emb_len = int(embedding.vec.shape[0])
                        fixes.append((len(remade_tokens), embedding))
                        remade_tokens += [0] * emb_len
                        multipliers += [mult] * emb_len
                        used_custom_terms.append((embedding.name, embedding.checksum()))
                        i += embedding_length_in_tokens

                if len(remade_tokens) > maxlen - 2:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                token_count = len(remade_tokens)
                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
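
    # Encodes a batch of prompts: tokenizes (storing fixes/comments on the hijack
    # object for EmbeddingsWithFixes and for later display), runs the CLIP
    # transformer, then scales each token's hidden state by its emphasis multiplier
    # and rescales the result so its mean matches the unweighted output.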
    def forward(self, text):
        if opts.use_old_emphasis_implementation:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
        else:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)

        self.hijack.fixes = hijack_fixes
        self.hijack.comments = hijack_comments

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z
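

# Wrapper around CLIP's token-embedding layer. For every (offset, embedding) fix
# recorded during tokenization it overwrites the placeholder positions in the
# embedding output with the textual-inversion vectors, clipping them so they never
# run past the end of the sequence.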
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = embedding.vec
                emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
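

# Alternative, global way to force circular ("tileable") padding: patches
# torch.nn.Conv2d.__init__ so every Conv2d created afterwards uses padding_mode
# 'circular'. It appears unused within this file (apply_circular() above toggles
# padding on the already-built model instead), so treat this as an optional helper.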
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular


model_hijack = StableDiffusionModelHijack()
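
# Rough usage sketch (an assumption; the exact call sites live elsewhere in the
# webui): after a Stable Diffusion checkpoint is loaded, something along the lines of
#
#     model_hijack.hijack(sd_model)        # wrap CLIP, apply attention optimizations
#     model_hijack.apply_circular(tiling)  # 'tiling' is a hypothetical per-job flag
#     ...
#     model_hijack.undo_hijack(sd_model)   # before switching to another checkpoint
#
# is expected, with model_hijack acting as the single shared instance.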