sd_hijack.py

import torch
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

sd_hijack_inpainting.do_inpainting_hijack()

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None


def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations(option=None):
    global current_optimizer

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    sgm.modules.diffusionmodules.model.nonlinearity = silu
    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''
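

# Illustrative sketch, not part of the original file: how an extension could plug its own
# attention optimization into list_optimizers()/apply_optimizations() above. It assumes the
# SdOptimization base class from sd_hijack_optimizations (name/priority/apply/undo) and the
# script_callbacks.on_list_optimizers registration hook; the class and function names below
# are hypothetical, and the callback shape (a list to append to) is an assumption.
def _example_register_attention_optimization():
    class ExampleOptimization(sd_hijack_optimizations.SdOptimization):
        name = "example"   # shown in the cross_attention_optimization setting
        priority = 10      # higher priority wins when the setting is "Automatic"

        def apply(self):
            # patch CrossAttention.forward for ldm/sgm here
            pass

    def add_optimizations(res):
        # assumed callback shape: receives the list being built and appends to it
        res.append(ExampleOptimization())

    script_callbacks.on_list_optimizers(add_optimizations)

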
def undo_optimizations():
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""
    pass


def weighted_loss(sd_model, pred, target, mean=True):
    # Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    # Return the loss, as mean if specified
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        # Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        # Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    # Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass
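

# Illustrative sketch, not part of the original file: how training code with per-element loss
# weights might call the method installed by apply_weighted_forward(). It assumes a
# LatentDiffusion-style model whose forward(x, c) returns its usual loss output; the function
# name and argument names are hypothetical.
def _example_weighted_training_step(sd_model, latents, cond, loss_weights):
    # weighted_forward temporarily swaps get_loss for weighted_loss, so each element of the
    # loss is multiplied by loss_weights before the mean is taken, then the patch is undone.
    return sd_model.weighted_forward(latents, cond, loss_weights)

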
class StableDiffusionModelHijack:
    fixes = None
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    def __init__(self):
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()

    def hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            text_cond_models = []

            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                typename = type(embedder).__name__
                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

        if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
            ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text):
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])
        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        self.undo_hijack(m)
        self.hijack(m)


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
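

# Illustrative sketch, not part of the original file: the data that feeds EmbeddingsWithFixes.
# The hijacked text encoder records, per prompt in the batch, (token offset, embedding) pairs
# on the hijack object's `fixes` before calling the wrapped transformer; forward() above then
# splices each textual-inversion vector over the token positions following the offset. The
# function name below is hypothetical and only shows the shape of the data involved.
def _example_set_textual_inversion_fixes(hijack, embedding):
    # one prompt in the batch, with `embedding` spliced in starting right after token offset 5
    hijack.fixes = [[(5, embedding)]]

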
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
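

# Illustrative sketch, not part of the original file: the two ways circular ("tiling") padding
# gets enabled. StableDiffusionModelHijack.apply_circular() flips padding_mode on the Conv2d
# layers of an already-hijacked model, while add_circular_option_to_conv_2d() patches the Conv2d
# constructor so models built afterwards start out with circular padding. The function name
# below is hypothetical.
def _example_enable_tiling(hijack, enable=True):
    # assumes hijack.hijack(model) has already run, so hijack.layers is populated
    hijack.apply_circular(enable)

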
model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)


ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer