sd_hijack.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. import torch
  2. from torch.nn.functional import silu
  3. from types import MethodType
  4. import modules.textual_inversion.textual_inversion
  5. from modules import devices, sd_hijack_optimizations, shared
  6. from modules.hypernetworks import hypernetwork
  7. from modules.shared import cmd_opts
  8. from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
  9. import ldm.modules.attention
  10. import ldm.modules.diffusionmodules.model
  11. import ldm.modules.diffusionmodules.openaimodel
  12. import ldm.models.diffusion.ddim
  13. import ldm.models.diffusion.plms
  14. import ldm.modules.encoders.modules
  15. attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
  16. diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
  17. diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
  18. # new memory efficient cross attention blocks do not support hypernets and we already
  19. # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
  20. ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
  21. ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
  22. # silence new console spam from SD2
  23. ldm.modules.attention.print = lambda *args: None
  24. ldm.modules.diffusionmodules.model.print = lambda *args: None
  25. def apply_optimizations():
  26. undo_optimizations()
  27. ldm.modules.diffusionmodules.model.nonlinearity = silu
  28. ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
  29. optimization_method = None
  30. can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
  31. if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
  32. print("Applying xformers cross attention optimization.")
  33. ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
  34. ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
  35. optimization_method = 'xformers'
  36. elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
  37. print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
  38. ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
  39. ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
  40. optimization_method = 'sdp-no-mem'
  41. elif cmd_opts.opt_sdp_attention and can_use_sdp:
  42. print("Applying scaled dot product cross attention optimization.")
  43. ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
  44. ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
  45. optimization_method = 'sdp'
  46. elif cmd_opts.opt_sub_quad_attention:
  47. print("Applying sub-quadratic cross attention optimization.")
  48. ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
  49. ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
  50. optimization_method = 'sub-quadratic'
  51. elif cmd_opts.opt_split_attention_v1:
  52. print("Applying v1 cross attention optimization.")
  53. ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
  54. optimization_method = 'V1'
  55. elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
  56. print("Applying cross attention optimization (InvokeAI).")
  57. ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
  58. optimization_method = 'InvokeAI'
  59. elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
  60. print("Applying cross attention optimization (Doggettx).")
  61. ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
  62. ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
  63. optimization_method = 'Doggettx'
  64. return optimization_method
  65. def undo_optimizations():
  66. ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
  67. ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
  68. ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
  69. def fix_checkpoint():
  70. """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
  71. checkpoints to be added when not training (there's a warning)"""
  72. pass
  73. def weighted_loss(sd_model, pred, target, mean=True):
  74. #Calculate the weight normally, but ignore the mean
  75. loss = sd_model._old_get_loss(pred, target, mean=False)
  76. #Check if we have weights available
  77. weight = getattr(sd_model, '_custom_loss_weight', None)
  78. if weight is not None:
  79. loss *= weight
  80. #Return the loss, as mean if specified
  81. return loss.mean() if mean else loss
  82. def weighted_forward(sd_model, x, c, w, *args, **kwargs):
  83. try:
  84. #Temporarily append weights to a place accessible during loss calc
  85. sd_model._custom_loss_weight = w
  86. #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
  87. #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
  88. if not hasattr(sd_model, '_old_get_loss'):
  89. sd_model._old_get_loss = sd_model.get_loss
  90. sd_model.get_loss = MethodType(weighted_loss, sd_model)
  91. #Run the standard forward function, but with the patched 'get_loss'
  92. return sd_model.forward(x, c, *args, **kwargs)
  93. finally:
  94. try:
  95. #Delete temporary weights if appended
  96. del sd_model._custom_loss_weight
  97. except AttributeError:
  98. pass
  99. #If we have an old loss function, reset the loss function to the original one
  100. if hasattr(sd_model, '_old_get_loss'):
  101. sd_model.get_loss = sd_model._old_get_loss
  102. del sd_model._old_get_loss
  103. def apply_weighted_forward(sd_model):
  104. #Add new function 'weighted_forward' that can be called to calc weighted loss
  105. sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
  106. def undo_weighted_forward(sd_model):
  107. try:
  108. del sd_model.weighted_forward
  109. except AttributeError:
  110. pass
  111. class StableDiffusionModelHijack:
  112. fixes = None
  113. comments = []
  114. layers = None
  115. circular_enabled = False
  116. clip = None
  117. optimization_method = None
  118. embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
  119. def __init__(self):
  120. self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
  121. def hijack(self, m):
  122. if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
  123. model_embeddings = m.cond_stage_model.roberta.embeddings
  124. model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
  125. m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
  126. elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
  127. model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
  128. model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
  129. m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
  130. elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
  131. m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
  132. m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
  133. apply_weighted_forward(m)
  134. if m.cond_stage_key == "edit":
  135. sd_hijack_unet.hijack_ddpm_edit()
  136. self.optimization_method = apply_optimizations()
  137. self.clip = m.cond_stage_model
  138. def flatten(el):
  139. flattened = [flatten(children) for children in el.children()]
  140. res = [el]
  141. for c in flattened:
  142. res += c
  143. return res
  144. self.layers = flatten(m)
  145. def undo_hijack(self, m):
  146. if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
  147. m.cond_stage_model = m.cond_stage_model.wrapped
  148. elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
  149. m.cond_stage_model = m.cond_stage_model.wrapped
  150. model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
  151. if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
  152. model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
  153. elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
  154. m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
  155. m.cond_stage_model = m.cond_stage_model.wrapped
  156. undo_optimizations()
  157. undo_weighted_forward(m)
  158. self.apply_circular(False)
  159. self.layers = None
  160. self.clip = None
  161. def apply_circular(self, enable):
  162. if self.circular_enabled == enable:
  163. return
  164. self.circular_enabled = enable
  165. for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
  166. layer.padding_mode = 'circular' if enable else 'zeros'
  167. def clear_comments(self):
  168. self.comments = []
  169. def get_prompt_lengths(self, text):
  170. _, token_count = self.clip.process_texts([text])
  171. return token_count, self.clip.get_target_prompt_token_count(token_count)
  172. class EmbeddingsWithFixes(torch.nn.Module):
  173. def __init__(self, wrapped, embeddings):
  174. super().__init__()
  175. self.wrapped = wrapped
  176. self.embeddings = embeddings
  177. def forward(self, input_ids):
  178. batch_fixes = self.embeddings.fixes
  179. self.embeddings.fixes = None
  180. inputs_embeds = self.wrapped(input_ids)
  181. if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
  182. return inputs_embeds
  183. vecs = []
  184. for fixes, tensor in zip(batch_fixes, inputs_embeds):
  185. for offset, embedding in fixes:
  186. emb = devices.cond_cast_unet(embedding.vec)
  187. emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
  188. tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
  189. vecs.append(tensor)
  190. return torch.stack(vecs)
  191. def add_circular_option_to_conv_2d():
  192. conv2d_constructor = torch.nn.Conv2d.__init__
  193. def conv2d_constructor_circular(self, *args, **kwargs):
  194. return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
  195. torch.nn.Conv2d.__init__ = conv2d_constructor_circular
  196. model_hijack = StableDiffusionModelHijack()
  197. def register_buffer(self, name, attr):
  198. """
  199. Fix register buffer bug for Mac OS.
  200. """
  201. if type(attr) == torch.Tensor:
  202. if attr.device != devices.device:
  203. attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
  204. setattr(self, name, attr)
  205. ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
  206. ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer