# sd_hijack_autoencoder.py
# The content of this file comes from the ldm/models/autoencoder.py file of the CompVis/stable-diffusion repo.
# VQModel and VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the
# Stability-AI/stablediffusion repo.
# As the LDSR upscaler relies on VQModel and VQModelInterface, this hijack puts them back into ldm.models.autoencoder.
  4. import numpy as np
  5. import torch
  6. import pytorch_lightning as pl
  7. import torch.nn.functional as F
  8. from contextlib import contextmanager
  9. from torch.optim.lr_scheduler import LambdaLR
  10. from ldm.modules.ema import LitEma
  11. from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
  12. from ldm.modules.diffusionmodules.model import Encoder, Decoder
  13. from ldm.util import instantiate_from_config
  14. import ldm.models.autoencoder
  15. from packaging import version
  16. class VQModel(pl.LightningModule):
  17. def __init__(self,
  18. ddconfig,
  19. lossconfig,
  20. n_embed,
  21. embed_dim,
  22. ckpt_path=None,
  23. ignore_keys=None,
  24. image_key="image",
  25. colorize_nlabels=None,
  26. monitor=None,
  27. batch_resize_range=None,
  28. scheduler_config=None,
  29. lr_g_factor=1.0,
  30. remap=None,
  31. sane_index_shape=False, # tell vector quantizer to return indices as bhw
  32. use_ema=False
  33. ):
  34. super().__init__()
  35. self.embed_dim = embed_dim
  36. self.n_embed = n_embed
  37. self.image_key = image_key
  38. self.encoder = Encoder(**ddconfig)
  39. self.decoder = Decoder(**ddconfig)
  40. self.loss = instantiate_from_config(lossconfig)
  41. self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
  42. remap=remap,
  43. sane_index_shape=sane_index_shape)
  44. self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
  45. self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
  46. if colorize_nlabels is not None:
  47. assert type(colorize_nlabels)==int
  48. self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
  49. if monitor is not None:
  50. self.monitor = monitor
  51. self.batch_resize_range = batch_resize_range
  52. if self.batch_resize_range is not None:
  53. print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
  54. self.use_ema = use_ema
  55. if self.use_ema:
  56. self.model_ema = LitEma(self)
  57. print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
  58. if ckpt_path is not None:
  59. self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
  60. self.scheduler_config = scheduler_config
  61. self.lr_g_factor = lr_g_factor
  62. @contextmanager
  63. def ema_scope(self, context=None):
  64. if self.use_ema:
  65. self.model_ema.store(self.parameters())
  66. self.model_ema.copy_to(self)
  67. if context is not None:
  68. print(f"{context}: Switched to EMA weights")
  69. try:
  70. yield None
  71. finally:
  72. if self.use_ema:
  73. self.model_ema.restore(self.parameters())
  74. if context is not None:
  75. print(f"{context}: Restored training weights")
  76. def init_from_ckpt(self, path, ignore_keys=None):
  77. sd = torch.load(path, map_location="cpu")["state_dict"]
  78. keys = list(sd.keys())
  79. for k in keys:
  80. for ik in ignore_keys or []:
  81. if k.startswith(ik):
  82. print("Deleting key {} from state_dict.".format(k))
  83. del sd[k]
  84. missing, unexpected = self.load_state_dict(sd, strict=False)
  85. print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
  86. if len(missing) > 0:
  87. print(f"Missing Keys: {missing}")
  88. print(f"Unexpected Keys: {unexpected}")
  89. def on_train_batch_end(self, *args, **kwargs):
  90. if self.use_ema:
  91. self.model_ema(self)
  92. def encode(self, x):
  93. h = self.encoder(x)
  94. h = self.quant_conv(h)
  95. quant, emb_loss, info = self.quantize(h)
  96. return quant, emb_loss, info
  97. def encode_to_prequant(self, x):
  98. h = self.encoder(x)
  99. h = self.quant_conv(h)
  100. return h
  101. def decode(self, quant):
  102. quant = self.post_quant_conv(quant)
  103. dec = self.decoder(quant)
  104. return dec
  105. def decode_code(self, code_b):
  106. quant_b = self.quantize.embed_code(code_b)
  107. dec = self.decode(quant_b)
  108. return dec
  109. def forward(self, input, return_pred_indices=False):
  110. quant, diff, (_,_,ind) = self.encode(input)
  111. dec = self.decode(quant)
  112. if return_pred_indices:
  113. return dec, diff, ind
  114. return dec, diff
  115. def get_input(self, batch, k):
  116. x = batch[k]
  117. if len(x.shape) == 3:
  118. x = x[..., None]
  119. x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
  120. if self.batch_resize_range is not None:
  121. lower_size = self.batch_resize_range[0]
  122. upper_size = self.batch_resize_range[1]
  123. if self.global_step <= 4:
  124. # do the first few batches with max size to avoid later oom
  125. new_resize = upper_size
  126. else:
  127. new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
  128. if new_resize != x.shape[2]:
  129. x = F.interpolate(x, size=new_resize, mode="bicubic")
  130. x = x.detach()
  131. return x
  132. def training_step(self, batch, batch_idx, optimizer_idx):
  133. # https://github.com/pytorch/pytorch/issues/37142
  134. # try not to fool the heuristics
  135. x = self.get_input(batch, self.image_key)
  136. xrec, qloss, ind = self(x, return_pred_indices=True)
  137. if optimizer_idx == 0:
  138. # autoencode
  139. aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
  140. last_layer=self.get_last_layer(), split="train",
  141. predicted_indices=ind)
  142. self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  143. return aeloss
  144. if optimizer_idx == 1:
  145. # discriminator
  146. discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
  147. last_layer=self.get_last_layer(), split="train")
  148. self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
  149. return discloss
  150. def validation_step(self, batch, batch_idx):
  151. log_dict = self._validation_step(batch, batch_idx)
  152. with self.ema_scope():
  153. self._validation_step(batch, batch_idx, suffix="_ema")
  154. return log_dict
  155. def _validation_step(self, batch, batch_idx, suffix=""):
  156. x = self.get_input(batch, self.image_key)
  157. xrec, qloss, ind = self(x, return_pred_indices=True)
  158. aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
  159. self.global_step,
  160. last_layer=self.get_last_layer(),
  161. split="val"+suffix,
  162. predicted_indices=ind
  163. )
  164. discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
  165. self.global_step,
  166. last_layer=self.get_last_layer(),
  167. split="val"+suffix,
  168. predicted_indices=ind
  169. )
  170. rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
  171. self.log(f"val{suffix}/rec_loss", rec_loss,
  172. prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
  173. self.log(f"val{suffix}/aeloss", aeloss,
  174. prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
  175. if version.parse(pl.__version__) >= version.parse('1.4.0'):
  176. del log_dict_ae[f"val{suffix}/rec_loss"]
  177. self.log_dict(log_dict_ae)
  178. self.log_dict(log_dict_disc)
  179. return self.log_dict
  180. def configure_optimizers(self):
  181. lr_d = self.learning_rate
  182. lr_g = self.lr_g_factor*self.learning_rate
  183. print("lr_d", lr_d)
  184. print("lr_g", lr_g)
  185. opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
  186. list(self.decoder.parameters())+
  187. list(self.quantize.parameters())+
  188. list(self.quant_conv.parameters())+
  189. list(self.post_quant_conv.parameters()),
  190. lr=lr_g, betas=(0.5, 0.9))
  191. opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
  192. lr=lr_d, betas=(0.5, 0.9))
  193. if self.scheduler_config is not None:
  194. scheduler = instantiate_from_config(self.scheduler_config)
  195. print("Setting up LambdaLR scheduler...")
  196. scheduler = [
  197. {
  198. 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
  199. 'interval': 'step',
  200. 'frequency': 1
  201. },
  202. {
  203. 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
  204. 'interval': 'step',
  205. 'frequency': 1
  206. },
  207. ]
  208. return [opt_ae, opt_disc], scheduler
  209. return [opt_ae, opt_disc], []
  210. def get_last_layer(self):
  211. return self.decoder.conv_out.weight
  212. def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
  213. log = {}
  214. x = self.get_input(batch, self.image_key)
  215. x = x.to(self.device)
  216. if only_inputs:
  217. log["inputs"] = x
  218. return log
  219. xrec, _ = self(x)
  220. if x.shape[1] > 3:
  221. # colorize with random projection
  222. assert xrec.shape[1] > 3
  223. x = self.to_rgb(x)
  224. xrec = self.to_rgb(xrec)
  225. log["inputs"] = x
  226. log["reconstructions"] = xrec
  227. if plot_ema:
  228. with self.ema_scope():
  229. xrec_ema, _ = self(x)
  230. if x.shape[1] > 3:
  231. xrec_ema = self.to_rgb(xrec_ema)
  232. log["reconstructions_ema"] = xrec_ema
  233. return log
  234. def to_rgb(self, x):
  235. assert self.image_key == "segmentation"
  236. if not hasattr(self, "colorize"):
  237. self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
  238. x = F.conv2d(x, weight=self.colorize)
  239. x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
  240. return x
  241. class VQModelInterface(VQModel):
  242. def __init__(self, embed_dim, *args, **kwargs):
  243. super().__init__(*args, embed_dim=embed_dim, **kwargs)
  244. self.embed_dim = embed_dim
  245. def encode(self, x):
  246. h = self.encoder(x)
  247. h = self.quant_conv(h)
  248. return h
  249. def decode(self, h, force_not_quantize=False):
  250. # also go through quantization layer
  251. if not force_not_quantize:
  252. quant, emb_loss, info = self.quantize(h)
  253. else:
  254. quant = h
  255. quant = self.post_quant_conv(quant)
  256. dec = self.decoder(quant)
  257. return dec
  258. ldm.models.autoencoder.VQModel = VQModel
  259. ldm.models.autoencoder.VQModelInterface = VQModelInterface