sd_hijack_autoencoder.py

# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface classes were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, this hijack puts them back into ldm.models.autoencoder

import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

from torch.optim.lr_scheduler import LambdaLR

from ldm.modules.ema import LitEma
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config

import ldm.models.autoencoder
from packaging import version

class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor
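
    # Context manager that temporarily swaps the model's weights for their EMA
    # counterparts and restores the original training weights on exit.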
    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")
    def init_from_ckpt(self, path, ignore_keys=None):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys or []:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")
    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff
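
    # Pull the image tensor out of the batch, convert it from NHWC to NCHW
    # float, and optionally resize it per batch within batch_resize_range.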
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x
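
    # Adversarial training alternates between two optimizers: index 0 updates
    # the autoencoder (generator) and index 1 updates the discriminator.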
    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss
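
    # Run validation once with the training weights and once with the EMA
    # weights (logged under the "_ema" suffix).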
    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict
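
    # Two Adam optimizers: one for the autoencoder (encoder, decoder, quantizer
    # and the 1x1 convs) and one for the discriminator, with optional per-step
    # LambdaLR schedulers built from scheduler_config.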
    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight
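
    # Collect inputs and reconstructions for image logging; inputs with more
    # than 3 channels (e.g. segmentation maps) are projected to RGB first.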
    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = {}
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
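
# Variant of VQModel used by LDSR: encode() stops before quantization and
# returns the pre-quant latents, while decode() optionally runs them through
# the quantizer first.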
class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

ldm.models.autoencoder.VQModel = VQModel
ldm.models.autoencoder.VQModelInterface = VQModelInterface
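
# With these assignments in place, code that resolves ldm.models.autoencoder.VQModel
# or ldm.models.autoencoder.VQModelInterface (for example via instantiate_from_config
# on the LDSR model config) picks up the classes defined above. Note that this module
# must be imported before the LDSR model is built for the hijack to take effect; how
# and where that import happens is up to the surrounding extension code.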