# sd_hijack.py
  1. import torch
  2. from torch.nn.functional import silu
  3. from types import MethodType
  4. from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
  5. from modules.hypernetworks import hypernetwork
  6. from modules.shared import cmd_opts
  7. from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
  8. import ldm.modules.attention
  9. import ldm.modules.diffusionmodules.model
  10. import ldm.modules.diffusionmodules.openaimodel
  11. import ldm.models.diffusion.ddpm
  12. import ldm.models.diffusion.ddim
  13. import ldm.models.diffusion.plms
  14. import ldm.modules.encoders.modules
  15. import sgm.modules.attention
  16. import sgm.modules.diffusionmodules.model
  17. import sgm.modules.diffusionmodules.openaimodel
  18. import sgm.modules.encoders.modules
# Originals saved at import time so undo_optimizations() can restore stock behavior later.
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

# cross-attention optimizations registered via script callbacks; filled by list_optimizers()
optimizers = []
# the optimization currently applied to the loaded model, if any
current_optimizer: sd_hijack_optimizations.SdOptimization = None

# wrap the ldm/sgm UNet forward so sd_unet can substitute a custom unet implementation;
# the value returned by patches.patch is kept so the unpatched forward can be selected later
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
  37. def list_optimizers():
  38. new_optimizers = script_callbacks.list_optimizers_callback()
  39. new_optimizers = [x for x in new_optimizers if x.is_available()]
  40. new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
  41. optimizers.clear()
  42. optimizers.extend(new_optimizers)
def apply_optimizations(option=None):
    """Apply a cross-attention optimization to the currently loaded model.

    option: title of the optimization to apply; when None, the
    opts.cross_attention_optimization setting is used. Returns the name of the
    applied optimization, or '' when none was applied.
    """
    global current_optimizer

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    sgm.modules.diffusionmodules.model.nonlinearity = silu
    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    # undo the previously applied optimization before applying a new one
    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        # prefer an optimizer explicitly requested on the command line; otherwise the highest-priority one
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        # requested title not found: fall back to the highest-priority optimizer
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''
  77. def undo_optimizations():
  78. ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
  79. ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
  80. ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
  81. sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
  82. sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
  83. sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
  84. def fix_checkpoint():
  85. """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
  86. checkpoints to be added when not training (there's a warning)"""
  87. pass
  88. def weighted_loss(sd_model, pred, target, mean=True):
  89. #Calculate the weight normally, but ignore the mean
  90. loss = sd_model._old_get_loss(pred, target, mean=False)
  91. #Check if we have weights available
  92. weight = getattr(sd_model, '_custom_loss_weight', None)
  93. if weight is not None:
  94. loss *= weight
  95. #Return the loss, as mean if specified
  96. return loss.mean() if mean else loss
def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    """Run sd_model.forward(x, c, ...) with per-element loss weights ``w``.

    Temporarily patches the model's get_loss with the weight-aware
    weighted_loss, then restores the original state no matter how forward
    exits.
    """
    try:
        # Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        # Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss
  118. def apply_weighted_forward(sd_model):
  119. #Add new function 'weighted_forward' that can be called to calc weighted loss
  120. sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
  121. def undo_weighted_forward(sd_model):
  122. try:
  123. del sd_model.weighted_forward
  124. except AttributeError:
  125. pass
class StableDiffusionModelHijack:
    """Applies and reverts the webui's modifications to a loaded Stable
    Diffusion model: wraps the text encoder(s) with custom-prompt/embedding
    classes, installs attention optimizations and weighted-loss support, and
    tracks model layers for the circular (tiling) convolution option."""

    # per-batch textual inversion replacements; read and cleared by EmbeddingsWithFixes.forward
    fixes = None
    # flat list of all model submodules, filled by hijack()
    layers = None
    # whether Conv2d layers are currently switched to circular padding
    circular_enabled = False
    # the wrapped cond_stage_model (text encoder); None until hijack() runs
    clip = None
    # name of the applied attention optimization, as returned by apply_optimizations()
    optimization_method = None

    def __init__(self):
        # imported lazily here, presumably to avoid a circular import — TODO confirm
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        """Apply the selected cross-attention optimization; on failure, report
        the error and revert to unoptimized attention."""
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()

    def convert_sdxl_to_ssd(self, m):
        """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
        delattr(m.model.diffusion_model.middle_block, '1')
        delattr(m.model.diffusion_model.middle_block, '2')
        for i in ['9', '8', '7', '6', '5', '4']:
            delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
            delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
        delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
        delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
        devices.torch_gc()

    def hijack(self, m):
        """Wrap the model's text encoder(s) with the webui's custom-word
        classes, hook token embeddings for textual inversion, apply the
        attention optimization, and record which UNet forward to restore."""
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            # SDXL-style model: wrap each embedder inside the conditioner
            text_cond_models = []

            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                typename = type(embedder).__name__

                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    # SDXL's second text encoder stores its textual inversion vectors under 'clip_g'
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

            # expose a single cond_stage_model: the lone wrapped embedder, or the whole conditioner
            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            # recursively collect el and all of its submodules into one flat list
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

        # imported lazily here, presumably to avoid a circular import — TODO confirm
        import modules.models.diffusion.ddpm_edit

        # remember the unpatched UNet forward matching this model family for sd_unet
        if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
            sd_unet.original_forward = ldm_original_forward
        elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
            sd_unet.original_forward = ldm_original_forward
        elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
            sd_unet.original_forward = sgm_original_forward
        else:
            sd_unet.original_forward = None

    def undo_hijack(self, m):
        """Reverse hijack(): unwrap text encoders and token embeddings, undo
        optimizations and weighted forward, and reset tracked state."""
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
                    embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped
                if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
                    embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
                    conditioner.embedders[i] = embedder.wrapped

            # hijack() may have set cond_stage_model as an alias for the conditioner; remove it
            if hasattr(m, 'cond_stage_model'):
                delattr(m, 'cond_stage_model')

        elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        """Toggle circular ('tiling') padding on every Conv2d layer collected
        by hijack(); no-op if already in the requested state."""
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        """Reset per-generation comments and extra generation parameters."""
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text):
        """Return (token_count, target_token_count) for a prompt, or a pair of
        dashes when no text encoder has been hijacked yet."""
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        """Undo and re-apply the hijack on the model."""
        self.undo_hijack(m)
        self.hijack(m)
  257. class EmbeddingsWithFixes(torch.nn.Module):
  258. def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
  259. super().__init__()
  260. self.wrapped = wrapped
  261. self.embeddings = embeddings
  262. self.textual_inversion_key = textual_inversion_key
  263. def forward(self, input_ids):
  264. batch_fixes = self.embeddings.fixes
  265. self.embeddings.fixes = None
  266. inputs_embeds = self.wrapped(input_ids)
  267. if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
  268. return inputs_embeds
  269. vecs = []
  270. for fixes, tensor in zip(batch_fixes, inputs_embeds):
  271. for offset, embedding in fixes:
  272. vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
  273. emb = devices.cond_cast_unet(vec)
  274. emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
  275. tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
  276. vecs.append(tensor)
  277. return torch.stack(vecs)
  278. def add_circular_option_to_conv_2d():
  279. conv2d_constructor = torch.nn.Conv2d.__init__
  280. def conv2d_constructor_circular(self, *args, **kwargs):
  281. return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
  282. torch.nn.Conv2d.__init__ = conv2d_constructor_circular
# the single module-level hijack instance shared across the application
model_hijack = StableDiffusionModelHijack()
  284. def register_buffer(self, name, attr):
  285. """
  286. Fix register buffer bug for Mac OS.
  287. """
  288. if type(attr) == torch.Tensor:
  289. if attr.device != devices.device:
  290. attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
  291. setattr(self, name, attr)
# route the DDIM/PLMS samplers' buffer registration through the MPS-safe register_buffer
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer