sd_hijack.py

import math
import os
import sys
import traceback
import torch
import numpy as np
from torch import einsum

from modules import prompt_parser
from modules.shared import opts, device, cmd_opts

from ldm.util import default
from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
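
# Replacement for ldm.modules.attention.CrossAttention.forward that computes attention
# over the (batch * heads) dimension two rows at a time, trading speed for lower peak VRAM.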
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
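
# Replacement for CrossAttention.forward that estimates free VRAM (CUDA free memory plus
# memory reserved-but-unused by torch) and slices the attention computation over the query
# tokens so the intermediate attention matrix fits into the available memory.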
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)
    k_in = self.to_k(context) * self.scale
    v_in = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

        s2 = s1.softmax(dim=-1, dtype=q.dtype)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
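
# In-place swish (x * sigmoid(x)); replaces ldm.modules.diffusionmodules.model.nonlinearity
# to avoid allocating an extra tensor.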
def nonlinearity_hijack(x):
    # swish
    t = torch.sigmoid(x)
    x *= t
    del t

    return x
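
# Replacement for the autoencoder AttnBlock.forward: same idea as split_cross_attention_forward,
# but for the VAE's self-attention block, slicing over the flattened spatial dimension.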
def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
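
# Coordinates all model patches: loads textual inversion embeddings, wraps the CLIP text
# encoder, and (optionally) installs the memory-optimised attention functions above.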
class StableDiffusionModelHijack:
    ids_lookup = {}
    word_embeddings = {}
    word_embeddings_checksums = {}
    fixes = None
    comments = []
    dir_mtime = None
    layers = None
    circular_enabled = False
    clip = None

    def load_textual_inversion_embeddings(self, dirname, model):
        mt = os.path.getmtime(dirname)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()

        tokenizer = model.cond_stage_model.tokenizer

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            data = torch.load(path, map_location="cpu")

            # textual inversion embeddings
            if 'string_to_param' in data:
                param_dict = data['string_to_param']
                if hasattr(param_dict, '_parameters'):
                    param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
                emb = next(iter(param_dict.items()))[1]
            # diffuser concepts
            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

                emb = next(iter(data.values()))
                if len(emb.shape) == 1:
                    emb = emb.unsqueeze(0)

            self.word_embeddings[name] = emb.detach().to(device)
            self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'

            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]

            first_id = ids[0]
            if first_id not in self.ids_lookup:
                self.ids_lookup[first_id] = []
            self.ids_lookup[first_id].append((ids, name))

        for fn in os.listdir(dirname):
            try:
                fullfn = os.path.join(dirname, fn)

                if os.stat(fullfn).st_size == 0:
                    continue

                process_file(fullfn, fn)
            except Exception:
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
    def hijack(self, m):
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        self.clip = m.cond_stage_model

        if cmd_opts.opt_split_attention_v1:
            ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
        elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
            ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
            ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
            ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def tokenize(self, text):
        max_length = self.clip.max_length - 2
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, max_length
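
# Wraps FrozenCLIPEmbedder: parses prompt emphasis syntax ( (word) / [word] ) into per-token
# multipliers and reserves slots for textual inversion embeddings found in the prompt.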
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
        self.token_mults = {}

        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult
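
    # Converts one prompt line into a fixed-length CLIP token sequence, recording per-token
    # emphasis multipliers and the positions where textual inversion embeddings must be injected.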
    def tokenize_line(self, line, used_custom_terms, hijack_comments):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length

        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]

        fixes = []
        remade_tokens = []
        multipliers = []

        for tokens, (text, weight) in zip(tokenized, parsed):
            i = 0
            while i < len(tokens):
                token = tokens[i]

                possible_matches = self.hijack.ids_lookup.get(token, None)

                if possible_matches is None:
                    remade_tokens.append(token)
                    multipliers.append(weight)
                else:
                    found = False
                    for ids, word in possible_matches:
                        if tokens[i:i + len(ids)] == ids:
                            emb_len = int(self.hijack.word_embeddings[word].shape[0])
                            fixes.append((len(remade_tokens), word))
                            remade_tokens += [0] * emb_len
                            multipliers += [weight] * emb_len
                            i += len(ids) - 1
                            found = True
                            used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                            break

                    if not found:
                        remade_tokens.append(token)
                        multipliers.append(weight)

                i += 1

        if len(remade_tokens) > maxlen - 2:
            vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
            ovf = remade_tokens[maxlen - 2:]
            overflowing_words = [vocab.get(int(x), "") for x in ovf]
            overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
            hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

        token_count = len(remade_tokens)
        remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
        remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]

        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        return remade_tokens, fixes, multipliers, token_count

    def process_text(self, texts):
        used_custom_terms = []
        remade_batch_tokens = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_multipliers = []
        for line in texts:
            if line in cache:
                remade_tokens, fixes, multipliers = cache[line]
            else:
                remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)

                cache[line] = (remade_tokens, fixes, multipliers)

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
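
    # Older emphasis implementation, used when opts.use_old_emphasis_implementation is enabled;
    # it applies multipliers directly from parenthesis/bracket tokens instead of prompt_parser.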
    def process_text_old(self, text):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length
        used_custom_terms = []
        remade_batch_tokens = []
        overflowing_words = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    possible_matches = self.hijack.ids_lookup.get(token, None)

                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                    if mult_change is not None:
                        mult *= mult_change
                    elif possible_matches is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                    else:
                        found = False
                        for ids, word in possible_matches:
                            if tokens[i:i+len(ids)] == ids:
                                emb_len = int(self.hijack.word_embeddings[word].shape[0])
                                fixes.append((len(remade_tokens), word))
                                remade_tokens += [0] * emb_len
                                multipliers += [mult] * emb_len
                                i += len(ids) - 1
                                found = True
                                used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                                break

                        if not found:
                            remade_tokens.append(token)
                            multipliers.append(mult)

                    i += 1

                if len(remade_tokens) > maxlen - 2:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                token_count = len(remade_tokens)
                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
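
    # Encodes the prompts with the wrapped CLIP transformer, then scales each token's hidden
    # state by its emphasis multiplier and rescales the result back to the original mean.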
    def forward(self, text):
        if opts.use_old_emphasis_implementation:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
        else:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)

        self.hijack.fixes = hijack_fixes
        self.hijack.comments = hijack_comments

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z
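
# Wraps the token embedding layer; after the normal embedding lookup it overwrites the
# placeholder positions recorded in hijack.fixes with the loaded textual inversion vectors.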
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is not None:
            for fixes, tensor in zip(batch_fixes, inputs_embeds):
                for offset, word in fixes:
                    emb = self.embeddings.word_embeddings[word]
                    emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
                    tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]

        return inputs_embeds
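
# Monkey-patches torch.nn.Conv2d so every newly constructed convolution uses circular padding
# (used for generating seamlessly tileable images).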
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular


model_hijack = StableDiffusionModelHijack()