sd_hijack.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. import torch
  2. from torch.nn.functional import silu
  3. from types import MethodType
  4. import modules.textual_inversion.textual_inversion
  5. from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
  6. from modules.hypernetworks import hypernetwork
  7. from modules.shared import cmd_opts
  8. from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
  9. import ldm.modules.attention
  10. import ldm.modules.diffusionmodules.model
  11. import ldm.modules.diffusionmodules.openaimodel
  12. import ldm.models.diffusion.ddim
  13. import ldm.models.diffusion.plms
  14. import ldm.modules.encoders.modules
  # Keep references to the original ldm implementations before anything patches them;
  # undo_optimizations() restores nonlinearity and AttnBlock.forward from these.
  attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
  diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
  diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

  # new memory efficient cross attention blocks do not support hypernets and we already
  # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
  # (chained assignment: MemoryEfficientCrossAttention is rebound first, then the ATTENTION_MODES entry)
  ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

  # silence new console spam from SD2 by shadowing print() inside these modules
  ldm.modules.attention.print = ldm.modules.diffusionmodules.model.print = lambda *args: None

  # Available cross-attention optimizations, highest priority first; refilled in place by list_optimizers().
  optimizers = []
  # Optimization currently applied by apply_optimizations(), or None when none is active.
  current_optimizer: sd_hijack_optimizations.SdOptimization = None
  def list_optimizers():
      """Refresh the module-level ``optimizers`` list from the registered callbacks.

      Only optimizers whose ``is_available()`` is true are kept, ordered by
      descending ``priority`` (stable for equal priorities). The existing list
      object is updated in place rather than rebound.
      """
      available = [opt for opt in script_callbacks.list_optimizers_callback() if opt.is_available()]
      available.sort(key=lambda opt: opt.priority, reverse=True)
      optimizers[:] = available
  def _find_optimizer(selection):
      """Return the entry of ``optimizers`` requested by ``selection``, or None to disable.

      ``selection`` is "None" (disable), "Automatic", or an optimizer title.
      "Automatic" is resolved as follows:
        * None when ``shared.cmd_opts.disable_opt_split_attention`` is set;
        * otherwise the first optimizer whose ``cmd_opt`` flag is set in ``shared.cmd_opts``;
        * otherwise ``optimizers[0]``.
      An unknown title falls back to ``optimizers[0]``.
      Callers must ensure ``optimizers`` is non-empty.
      """
      if selection == "None":
          return None

      if selection == "Automatic":
          if shared.cmd_opts.disable_opt_split_attention:
              return None
          # Prefer an optimizer explicitly requested through its command-line option.
          return next(
              (x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)),
              optimizers[0],
          )

      return next((x for x in optimizers if x.title() == selection), optimizers[0])


  def apply_optimizations(option=None):
      """Apply the requested cross-attention optimization and return its name.

      The previous optimization is always undone first.

      Args:
          option: Optimizer title, "Automatic" or "None". When falsy,
              ``shared.opts.cross_attention_optimization`` is used instead.

      Returns:
          The applied optimizer's ``name``, or '' when no optimization is applied.
      """
      global current_optimizer

      undo_optimizations()

      if len(optimizers) == 0:
          # a script can access the model very early, and optimizations would not be filled by then.
          # Still undo the previously applied optimizer so its patches do not outlive the reset.
          if current_optimizer is not None:
              current_optimizer.undo()
          current_optimizer = None
          return ''

      ldm.modules.diffusionmodules.model.nonlinearity = silu
      ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

      if current_optimizer is not None:
          current_optimizer.undo()
          current_optimizer = None

      selection = option or shared.opts.cross_attention_optimization
      matching_optimizer = _find_optimizer(selection)

      if matching_optimizer is None:
          print("Disabling attention optimization")
          return ''

      print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
      matching_optimizer.apply()
      print("done.")
      current_optimizer = matching_optimizer
      return current_optimizer.name
  def undo_optimizations():
      """Restore the unpatched ldm functions.

      Resets ``nonlinearity`` and ``AttnBlock.forward`` from the references saved
      at import time. ``CrossAttention.forward`` is reset to the hypernetwork
      module's ``attention_CrossAttention_forward``.
      """
      model_module = ldm.modules.diffusionmodules.model
      model_module.nonlinearity = diffusionmodules_model_nonlinearity
      ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
      model_module.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
  def fix_checkpoint():
      """Intentionally does nothing.

      Checkpoints are now added and removed in the embedding/hypernetwork code,
      because torch warns when checkpoints are added outside of training.
      """
      return None
  def weighted_loss(sd_model, pred, target, mean=True):
      """Weight-aware replacement for ``get_loss``, installed by weighted_forward().

      Calls the saved original ``sd_model._old_get_loss`` with ``mean=False``.
      If ``sd_model._custom_loss_weight`` is present, the unreduced loss is
      multiplied by it in place. The result is averaged only when ``mean`` is true.
      """
      per_element = sd_model._old_get_loss(pred, target, mean=False)

      loss_weight = getattr(sd_model, '_custom_loss_weight', None)
      if loss_weight is not None:
          per_element *= loss_weight

      if not mean:
          return per_element
      return per_element.mean()
  def _restore_get_loss(sd_model):
      """Remove the temporary state that weighted_forward() installs on ``sd_model``."""
      # The weights may never have been attached; that is fine.
      try:
          del sd_model._custom_loss_weight
      except AttributeError:
          pass

      # Put the saved get_loss back and drop the saved reference.
      if hasattr(sd_model, '_old_get_loss'):
          sd_model.get_loss = sd_model._old_get_loss
          del sd_model._old_get_loss


  def weighted_forward(sd_model, x, c, w, *args, **kwargs):
      """Run ``sd_model.forward(x, c, ...)`` with its loss weighted by ``w``.

      The weights are stashed on the model. ``get_loss`` is then temporarily
      swapped for weighted_loss(), so ``forward`` does not need reimplementing.
      An ``_old_get_loss`` that is already saved is not overwritten. The model is
      restored in ``finally``, even when ``forward`` raises.
      """
      try:
          sd_model._custom_loss_weight = w
          if not hasattr(sd_model, '_old_get_loss'):
              sd_model._old_get_loss = sd_model.get_loss
          sd_model.get_loss = MethodType(weighted_loss, sd_model)
          return sd_model.forward(x, c, *args, **kwargs)
      finally:
          _restore_get_loss(sd_model)
  def apply_weighted_forward(sd_model):
      """Attach weighted_forward() to ``sd_model`` as a bound method named ``weighted_forward``."""
      bound_method = MethodType(weighted_forward, sd_model)
      sd_model.weighted_forward = bound_method
  def undo_weighted_forward(sd_model):
      """Remove the ``weighted_forward`` attribute added by apply_weighted_forward(), if present."""
      try:
          del sd_model.weighted_forward
      except AttributeError:
          # Never applied, or already removed: nothing to do.
          return
  class StableDiffusionModelHijack:
      """Installs and removes the webui's runtime patches on a loaded Stable Diffusion model.

      hijack() does the following:
        * wraps the text encoder's token embedding in EmbeddingsWithFixes;
        * replaces ``cond_stage_model`` with a custom-words wrapper;
        * adds ``weighted_forward`` and applies the cross-attention optimization;
        * swaps in the webui ``UNetModel.forward``.
      undo_hijack() reverses those changes.
      """

      # Pending per-batch embedding fixes; read and reset to None by EmbeddingsWithFixes.forward.
      # NOTE(review): the writer is not in this file, presumably the custom-words encoders.
      fixes = None
      # Class-level default kept for backward compatibility; __init__ gives each instance its own list.
      comments = []
      # Every submodule of the hijacked model (filled by hijack()), or None when not hijacked.
      layers = None
      circular_enabled = False
      # The wrapped text encoder after hijack(), or None when not hijacked.
      clip = None
      # Last name returned by the module-level apply_optimizations() ('' when none applied).
      optimization_method = None

      # Shared by all instances; created once when the class is defined.
      embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

      def __init__(self):
          # Avoid sharing (and mutating) the mutable class-level ``comments`` list across instances.
          self.comments = []
          self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

      def apply_optimizations(self, option=None):
          """Apply a cross-attention optimization and record its name in ``optimization_method``.

          Exceptions are reported via errors.display and every optimization is
          undone, instead of propagating the exception.
          """
          try:
              self.optimization_method = apply_optimizations(option)
          except Exception as e:
              errors.display(e, "applying cross attention optimization")
              undo_optimizations()

      def hijack(self, m):
          """Patch the Stable Diffusion model ``m`` in place."""
          if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
              model_embeddings = m.cond_stage_model.roberta.embeddings
              # NOTE(review): this adds a ``token_embedding`` attribute wrapping ``word_embeddings``
              # rather than replacing ``word_embeddings``; presumably the XLM-R wrapper reads it — confirm.
              model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
              m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

          elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
              model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
              model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
              m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

          elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
              m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
              m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

          apply_weighted_forward(m)
          if m.cond_stage_key == "edit":
              sd_hijack_unet.hijack_ddpm_edit()

          self.apply_optimizations()

          self.clip = m.cond_stage_model

          def flatten(el):
              # Pre-order list of ``el`` and all of its descendants (shared children appear once per parent).
              flattened = [flatten(children) for children in el.children()]
              res = [el]
              for c in flattened:
                  res += c
              return res

          self.layers = flatten(m)

          # Save the original UNet forward only once, so repeated hijacks never save the patched one.
          if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
              ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

          ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

      def undo_hijack(self, m):
          """Reverse hijack() on ``m``: unwrap the text encoder and restore the patched code."""
          if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
              # hijack() replaces the XLM-R encoder with this wrapper, so test for the wrapper;
              # a check for BertSeriesModelWithTransformation never matches a hijacked model.
              m.cond_stage_model = m.cond_stage_model.wrapped

          elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
              m.cond_stage_model = m.cond_stage_model.wrapped

              model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
              if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                  model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

          elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
              m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
              m.cond_stage_model = m.cond_stage_model.wrapped

          undo_optimizations()
          undo_weighted_forward(m)

          self.apply_circular(False)
          self.layers = None
          self.clip = None

          # The saved copy exists only after hijack() has run; otherwise there is nothing to restore.
          original_forward = getattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui', None)
          if original_forward is not None:
              ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = original_forward

      def apply_circular(self, enable):
          """Switch every plain ``torch.nn.Conv2d`` in ``self.layers`` between circular and zero padding.

          Expects hijack() to have filled ``self.layers``.
          """
          if self.circular_enabled == enable:
              return

          self.circular_enabled = enable

          padding_mode = 'circular' if enable else 'zeros'
          for layer in self.layers:
              # Exact type check on purpose: Conv2d subclasses are left alone.
              if type(layer) == torch.nn.Conv2d:
                  layer.padding_mode = padding_mode

      def clear_comments(self):
          self.comments = []

      def get_prompt_lengths(self, text):
          """Return ``(token_count, target_token_count)`` for ``text``, or ``("-", "-")`` when not hijacked."""
          if self.clip is None:
              return "-", "-"

          _, token_count = self.clip.process_texts([text])

          return token_count, self.clip.get_target_prompt_token_count(token_count)

      def redo_hijack(self, m):
          """Undo and re-apply the hijack on ``m``."""
          self.undo_hijack(m)
          self.hijack(m)
  class EmbeddingsWithFixes(torch.nn.Module):
      """Wraps an embedding layer and splices custom vectors into its output.

      ``embeddings`` is the object holding a ``fixes`` attribute (the model hijack
      passes itself). ``fixes`` holds one list per batch row of
      ``(offset, embedding)`` pairs. Each ``embedding.vec`` overwrites the rows
      following position ``offset``, truncated so the row length is unchanged.
      """

      def __init__(self, wrapped, embeddings):
          super().__init__()
          self.wrapped = wrapped
          self.embeddings = embeddings

      def forward(self, input_ids):
          # Take the pending fixes and clear them so they apply to exactly one forward pass.
          batch_fixes, self.embeddings.fixes = self.embeddings.fixes, None

          inputs_embeds = self.wrapped(input_ids)

          nothing_to_patch = (
              batch_fixes is None
              or len(batch_fixes) == 0
              or max(len(row_fixes) for row_fixes in batch_fixes) == 0
          )
          if nothing_to_patch:
              return inputs_embeds

          # NOTE(review): rows are presumably (tokens, dim) slices of a (batch, tokens, dim) tensor — confirm.
          patched_rows = []
          for row_fixes, row in zip(batch_fixes, inputs_embeds):
              for offset, embedding in row_fixes:
                  vec = devices.cond_cast_unet(embedding.vec)
                  room = min(row.shape[0] - offset - 1, vec.shape[0])
                  row = torch.cat([row[:offset + 1], vec[:room], row[offset + 1 + room:]])
              patched_rows.append(row)

          return torch.stack(patched_rows)
  def add_circular_option_to_conv_2d():
      """Monkey-patch ``torch.nn.Conv2d.__init__`` so new convolutions use circular padding.

      The wrapper always passes ``padding_mode='circular'``. A caller that also
      passes ``padding_mode`` as a keyword would get a duplicate-keyword TypeError.
      """
      original_init = torch.nn.Conv2d.__init__

      def circular_init(self, *args, **kwargs):
          return original_init(self, *args, padding_mode='circular', **kwargs)

      torch.nn.Conv2d.__init__ = circular_init
  # Single hijack instance created at import time; its __init__ registers cmd_opts.embeddings_dir.
  model_hijack: StableDiffusionModelHijack = StableDiffusionModelHijack()
  def register_buffer(self, name, attr):
      """Store ``attr`` on ``self`` as ``name``, moving plain tensors to ``devices.device``.

      Replacement for the samplers' ``register_buffer`` (fixes a register buffer
      bug on Mac OS). Only exact ``torch.Tensor`` instances not already on the
      target device are moved. When the target is an 'mps' device they are also
      cast to float32 (presumably because MPS lacks float64 support — confirm).
      Any other value is stored unchanged.
      """
      target = devices.device
      if type(attr) == torch.Tensor and attr.device != target:
          cast_to = torch.float32 if target.type == 'mps' else None
          attr = attr.to(device=target, dtype=cast_to)
      setattr(self, name, attr)
  # Route DDIM and PLMS sampler buffer registration through this module's register_buffer
  # (chained assignment: DDIMSampler is patched first, then PLMSSampler).
  ldm.models.diffusion.ddim.DDIMSampler.register_buffer = ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer