# sd_hijack.py

import torch
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.ddpm  # imported explicitly: patched below via ldm.models.diffusion.ddpm.print
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
import ldm.util  # imported explicitly: patched below via ldm.util.print

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules
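
# keep references to the original ldm callables before anything below replaces them,
# so the stock versions remain reachable after the monkeypatching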
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
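
# asks extensions (via script_callbacks) which cross-attention optimizations are
# available and rebuilds the module-level `optimizers` list, highest priority first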
def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)
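
# applies the selected cross-attention optimization (explicit `option`, the
# shared.opts.cross_attention_optimization setting, or an automatic pick based on
# command-line flags); returns the applied optimization's name, or '' if none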
def apply_optimizations(option=None):
    global current_optimizer

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    sgm.modules.diffusionmodules.model.nonlinearity = silu
    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''
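
# restores the stock (or hypernetwork-aware) forward and nonlinearity functions in both
# the ldm and sgm codebases, removing whatever optimization was applied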
def undo_optimizations():
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""
    pass
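
# weighted_loss/weighted_forward let callers run the model's forward pass with a loss
# that is multiplied elementwise by a weight tensor `w`, without reimplementing forward()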
def weighted_loss(sd_model, pred, target, mean=True):
    #Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    #Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    #Return the loss, as mean if specified
    return loss.mean() if mean else loss

def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        #Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        #Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            #Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        #If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss

def apply_weighted_forward(sd_model):
    #Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass
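
# central hijack object: wraps a loaded checkpoint's text encoder(s) with webui's
# prompt-handling classes, hooks up the textual inversion embedding database, and
# applies/undoes attention optimizations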
class StableDiffusionModelHijack:
    fixes = None
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    def __init__(self):
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()
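
    # wraps the model's conditioning stage(s): SDXL checkpoints expose a `conditioner`
    # with multiple embedders, SD1/SD2/alt-diffusion checkpoints a single cond_stage_model;
    # each text encoder gets its token embedding wrapped in EmbeddingsWithFixes and the
    # encoder itself wrapped in the matching *WithCustomWords class, and the UNet forward
    # is redirected through sd_unet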
    def hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            text_cond_models = []

            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                typename = type(embedder).__name__

                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

        if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
            ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
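
    # reverses hijack(): unwraps the text encoders, restores the original loss/forward
    # patches and UNetModel.forward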
    def undo_hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
                    embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped
                if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
                    embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped

            if hasattr(m, 'cond_stage_model'):
                delattr(m, 'cond_stage_model')

        elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
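
    # toggles circular (tileable) padding on every Conv2d in the hijacked model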
    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text):
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])
        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        self.undo_hijack(m)
        self.hijack(m)
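
# wraps a token embedding layer; when textual inversion embeddings have been queued for
# the current batch in `embeddings.fixes`, their vectors are spliced into the embedding
# output at the recorded token offsets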
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
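
# makes every newly constructed torch.nn.Conv2d use circular padding by patching its constructor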
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
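
# the module-level instance used by the rest of the webui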
model_hijack = StableDiffusionModelHijack()

def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)
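
# install the MPS-safe register_buffer on ldm's DDIM and PLMS samplers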
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer