# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)
# Original filename: ldm/models/diffusion/ddpm.py
# The purpose is to reinstate the old DDPM logic, which works with VQ, whereas the V2 one doesn't
# Some models such as LDSR require VQ to work correctly
# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only

from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler

import ldm.models.diffusion.ddpm

__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


class DDPMV1(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=None,
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapperV1(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()
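
    # The buffers registered above give the closed-form forward process used by q_sample()
    # further down: x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise.
    # A minimal inspection sketch (illustrative only, not part of the original file):
    #
    #   model = DDPMV1(unet_config)            # any valid UNet config for instantiate_from_config
    #   model.betas.shape                      # torch.Size([1000]) with the default schedule
    #   model.sqrt_alphas_cumprod[0]           # close to 1 at t=0, decaying towards 0 at t=T-1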

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys or []:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                  return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        log = {}
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = []
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
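

# A minimal usage sketch for the unconditional DDPMV1 above (hypothetical; not part of the
# original file). The UNet is built via instantiate_from_config, so `unet_config` is any
# dict-like config with "target" and "params":
#
#   ddpm = DDPMV1(unet_config, timesteps=1000, image_size=64, channels=3)
#   imgs = ddpm.sample(batch_size=4)                          # full T-step ancestral sampling
#   imgs, steps = ddpm.sample(batch_size=4, return_intermediates=True)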


class LatentDiffusionV1(DDPMV1):
    """main class"""
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True

    def make_cond_schedule(self, ):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")

    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()
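
    # Note: make_cond_schedule() maps each of the T diffusion steps to one of
    # num_timesteps_cond conditioning timesteps; with the default num_timesteps_cond=1,
    # shorten_cond_schedule is False and this machinery stays inactive.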

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model
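
    # Note: cond_stage_config accepts two sentinel strings in addition to a regular config:
    # "__is_first_stage__" reuses the first-stage autoencoder as the conditioning encoder,
    # and "__is_unconditional__" (handled in __init__) disables conditioning entirely.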

    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                       force_not_quantize=force_no_decoder_quantization))
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to image border,
         with min distance = 0 at border and max dist = 0.5 at image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
        return edge_dist

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                               self.split_input_params["clip_max_weight"], )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(L_weighting,
                                     self.split_input_params["clip_min_tie_weight"],
                                     self.split_input_params["clip_max_tie_weight"])

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting

    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                                dilation=1, padding=0,
                                stride=(stride[0] * uf, stride[1] * uf))
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                                dilation=1, padding=0,
                                stride=(stride[0] // df, stride[1] // df))
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting
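
    # Worked example of the crop arithmetic above (illustrative only): for a 512x512 input
    # with ks=(128, 128) and stride=(64, 64), Ly = Lx = (512 - 128) // 64 + 1 = 7, i.e. 49
    # overlapping crops; `weighting` down-weights crop borders and `normalization` divides
    # out the overlap when the crops are folded back together.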

    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['caption', 'coordinates_bbox']:
                    xc = batch[cond_key]
                elif cond_key == 'class_label':
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    # import pudb; pudb.set_trace()
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                uf = self.split_input_params["vqf"]
                bs, nc, h, w = z.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (bn, nc * prod(**ks), L)
                # 1. Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # 2. apply model loop over last dim
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)
                o = o * weighting
                # Reverse 1. reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)

    # same as above but without decorator
    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                uf = self.split_input_params["vqf"]
                bs, nc, h, w = z.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (bn, nc * prod(**ks), L)
                # 1. Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # 2. apply model loop over last dim
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)
                o = o * weighting
                # Reverse 1. reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)

    @torch.no_grad()
    def encode_first_stage(self, x):
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                df = self.split_input_params["vqf"]
                self.split_input_params['original_image_size'] = x.shape[-2:]
                bs, nc, h, w = x.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
                z = unfold(x)  # (bn, nc * prod(**ks), L)
                # Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)
                o = o * weighting

                # Reverse reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization
                return decoded

            else:
                return self.first_stage_model.encode(x)
        else:
            return self.first_stage_model.encode(x)
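
    # Note: split_input_params is never set in this file; it is expected to be attached to
    # the model from outside (e.g. for tiled VQ encoding/decoding of large inputs). When the
    # attribute is absent, encode/decode fall back to a single pass over the full tensor.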

    def shared_step(self, batch, **kwargs):
        x, c = self.get_input(batch, self.first_stage_key)
        loss = self(x, c)
        return loss

    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)

    def apply_model(self, x_noisy, t, cond, return_ids=False):

        if isinstance(cond, dict):
            # hybrid case, cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}

        if hasattr(self, "split_input_params"):
            assert len(cond) == 1  # todo can only deal with one conditioning atm
            assert not return_ids
            ks = self.split_input_params["ks"]  # eg. (128, 128)
            stride = self.split_input_params["stride"]  # eg. (64, 64)

            h, w = x_noisy.shape[-2:]

            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)

            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
            # Reshape to img shape
            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]

            if self.cond_stage_key in ["image", "LR_image", "segmentation",
                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness
                c_key = next(iter(cond.keys()))  # get key
                c = next(iter(cond.values()))  # get value
                assert (len(c) == 1)  # todo extend to list with more than one elem
                c = c[0]  # get element

                c = unfold(c)
                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]

            elif self.cond_stage_key == 'coordinates_bbox':
                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'

                # assuming padding of unfold is always 0 and its dilation is always 1
                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
                full_img_h, full_img_w = self.split_input_params['original_image_size']
                # as we are operating on latents, we need the factor from the original image size to the
                # spatial latent size to properly rescale the crops for regenerating the bbox annotations
                num_downs = self.first_stage_model.encoder.num_resolutions - 1
                rescale_latent = 2 ** (num_downs)

                # get top left positions of patches as conforming for the bbox tokenizer, therefore we
                # need to rescale the tl patch coordinates to be in between (0,1)
                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
                                        for patch_nr in range(z.shape[-1])]

                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
                patch_limits = [(x_tl, y_tl,
                                 rescale_latent * ks[0] / full_img_w,
                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]

                # tokenize crop coordinates for the bounding boxes of the respective patches
                patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)
                print(patch_limits_tknzd[0].shape)
                # cut tknzd crop position from conditioning
                assert isinstance(cond, dict), 'cond must be dict to be fed into model'
                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
                print(cut_cond.shape)

                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
                print(adapted_cond.shape)
                adapted_cond = self.get_learned_conditioning(adapted_cond)
                print(adapted_cond.shape)
                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
                print(adapted_cond.shape)

                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]

            else:
                cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient

            # apply model by loop over crops
            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
            assert not isinstance(output_list[0],
                                  tuple)  # todo cant deal with multiple model outputs check this never happens

            o = torch.stack(output_list, axis=-1)
            o = o * weighting
            # Reverse reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together
            x_recon = fold(o) / normalization

        else:
            x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
        return mean_flat(kl_prior) / np.log(2.0)
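
    # Note: normal_kl returns the KL divergence in nats; dividing by log(2) converts it to
    # bits per dimension, hence mean_flat(kl_prior) / np.log(2.0) above.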

    def p_losses(self, x_start, cond, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
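
    # Note: the per-timestep weighting above is loss_simple / exp(logvar_t) + logvar_t, a
    # learned heteroscedastic weighting when learn_logvar is enabled; with the defaults
    # (logvar_init=0, learn_logvar=False) it reduces to plain loss_simple.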

    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        if return_codebook_ids:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size if batch_size is not None else shape[0]
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                        [x[:batch_size] for x in cond[key]] for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if type(temperature) == float:
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(img, cond, ts,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback:
                callback(i)
            if img_callback:
                img_callback(img, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback:
                callback(i)
            if img_callback:
                img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None, **kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                        [x[:batch_size] for x in cond[key]] for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                         shape, cond, verbose=False, **kwargs)
        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True, **kwargs)

        return samples, intermediates
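
    # Usage sketch (hypothetical; mirrors how log_images below calls it): given a trained
    # LatentDiffusionV1 instance `model` and encoded conditioning `c`,
    #
    #   samples, intermediates = model.sample_log(cond=c, batch_size=4, ddim=True,
    #                                             ddim_steps=50, eta=0.0)
    #   imgs = model.decode_first_stage(samples)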
    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, **kwargs):

        use_ddim = ddim_steps is not None

        log = {}
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                           return_first_stage_outputs=True,
                                           force_c_encode=True,
                                           return_original_cond=True,
                                           bs=N)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = []
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with self.ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             quantize_denoised=True)
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

            if inpaint:
                # make a simple center square
                h, w = z.shape[2], z.shape[3]
                mask = torch.ones(N, h, w).to(self.device)
                # zeros will be filled in
                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
                mask = mask[:, None, ...]
                with self.ema_scope("Plotting Inpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_inpainting"] = x_samples
                log["mask"] = mask

                # outpaint
                with self.ema_scope("Plotting Outpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with self.ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log
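    # Optimizes the diffusion (UNet) parameters with AdamW, optionally together
    # with the conditioning-stage parameters and the learned per-timestep logvar;
    # if a scheduler config is given, wraps it in a per-step LambdaLR schedule.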
    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt
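    # Maps an arbitrary-channel tensor to 3 channels for visualization using a
    # fixed random 1x1 convolution, then rescales the result to [-1, 1].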
    @torch.no_grad()
    def to_rgb(self, x):
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
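# Thin wrapper around the UNet that dispatches on `conditioning_key`: plain
# unconditional call, channel-wise concatenation ('concat'), cross-attention
# context ('crossattn'), both ('hybrid'), or class/embedding conditioning
# passed as `y` ('adm').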
class DiffusionWrapperV1(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
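# Layout-to-image variant: conditions on tokenized bounding boxes
# ('coordinates_bbox') and additionally logs a rendered bbox image per sample.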
class Layout2ImgDiffusionV1(LatentDiffusionV1):
    # TODO: move all layout-specific hacks to this class
    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        logs = super().log_images(*args, batch=batch, N=N, **kwargs)

        key = 'train' if self.training else 'validation'
        dset = self.trainer.datamodule.datasets[key]
        mapper = dset.conditional_builders[self.cond_stage_key]

        bbox_imgs = []
        map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
        for tknzd_bbox in batch[self.cond_stage_key][:N]:
            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
            bbox_imgs.append(bboximg)

        cond_img = torch.stack(bbox_imgs, dim=0)
        logs['bbox_image'] = cond_img
        return logs
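# Expose the V1 classes as attributes of the original ldm ddpm module, presumably
# so that configs and code referring to e.g. ldm.models.diffusion.ddpm.LatentDiffusionV1
# resolve to the implementations defined in this file.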
ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1