ddpm_edit.py

  1. """
  2. wild mixture of
  3. https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
  4. https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
  5. https://github.com/CompVis/taming-transformers
  6. -- merci
  7. """
  8. # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
  9. # See more details in LICENSE.
  10. import torch
  11. import torch.nn as nn
  12. import numpy as np
  13. import pytorch_lightning as pl
  14. from torch.optim.lr_scheduler import LambdaLR
  15. from einops import rearrange, repeat
  16. from contextlib import contextmanager
  17. from functools import partial
  18. from tqdm import tqdm
  19. from torchvision.utils import make_grid
  20. from pytorch_lightning.utilities.distributed import rank_zero_only
  21. from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
  22. from ldm.modules.ema import LitEma
  23. from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
  24. from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
  25. from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
  26. from ldm.models.diffusion.ddim import DDIMSampler
  27. try:
  28. from ldm.models.autoencoder import VQModelInterface
  29. except Exception:
  30. class VQModelInterface:
  31. pass
  32. __conditioning_keys__ = {'concat': 'c_concat',
  33. 'crossattn': 'c_crossattn',
  34. 'adm': 'y'}
  35. def disabled_train(self, mode=True):
  36. """Overwrite model.train with this function to make sure train/eval mode
  37. does not change anymore."""
  38. return self
  39. def uniform_on_device(r1, r2, shape, device):
  40. return (r1 - r2) * torch.rand(*shape, device=device) + r2
  41. class DDPM(pl.LightningModule):
  42. # classic DDPM with Gaussian diffusion, in image space
  43. def __init__(self,
  44. unet_config,
  45. timesteps=1000,
  46. beta_schedule="linear",
  47. loss_type="l2",
  48. ckpt_path=None,
  49. ignore_keys=None,
  50. load_only_unet=False,
  51. monitor="val/loss",
  52. use_ema=True,
  53. first_stage_key="image",
  54. image_size=256,
  55. channels=3,
  56. log_every_t=100,
  57. clip_denoised=True,
  58. linear_start=1e-4,
  59. linear_end=2e-2,
  60. cosine_s=8e-3,
  61. given_betas=None,
  62. original_elbo_weight=0.,
  63. v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
  64. l_simple_weight=1.,
  65. conditioning_key=None,
  66. parameterization="eps", # all assuming fixed variance schedules
  67. scheduler_config=None,
  68. use_positional_encodings=False,
  69. learn_logvar=False,
  70. logvar_init=0.,
  71. load_ema=True,
  72. ):
  73. super().__init__()
  74. assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
  75. self.parameterization = parameterization
  76. print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
  77. self.cond_stage_model = None
  78. self.clip_denoised = clip_denoised
  79. self.log_every_t = log_every_t
  80. self.first_stage_key = first_stage_key
  81. self.image_size = image_size # try conv?
  82. self.channels = channels
  83. self.use_positional_encodings = use_positional_encodings
  84. self.model = DiffusionWrapper(unet_config, conditioning_key)
  85. count_params(self.model, verbose=True)
  86. self.use_ema = use_ema
  87. self.use_scheduler = scheduler_config is not None
  88. if self.use_scheduler:
  89. self.scheduler_config = scheduler_config
  90. self.v_posterior = v_posterior
  91. self.original_elbo_weight = original_elbo_weight
  92. self.l_simple_weight = l_simple_weight
  93. if monitor is not None:
  94. self.monitor = monitor
  95. if self.use_ema and load_ema:
  96. self.model_ema = LitEma(self.model)
  97. print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
  98. if ckpt_path is not None:
  99. self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
  100. # If initializing from an EMA-only checkpoint, create the EMA model after loading.
  101. if self.use_ema and not load_ema:
  102. self.model_ema = LitEma(self.model)
  103. print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
  104. self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
  105. linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
  106. self.loss_type = loss_type
  107. self.learn_logvar = learn_logvar
  108. self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
  109. if self.learn_logvar:
  110. self.logvar = nn.Parameter(self.logvar, requires_grad=True)
  111. def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
  112. linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
  113. if exists(given_betas):
  114. betas = given_betas
  115. else:
  116. betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
  117. cosine_s=cosine_s)
  118. alphas = 1. - betas
  119. alphas_cumprod = np.cumprod(alphas, axis=0)
  120. alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
  121. timesteps, = betas.shape
  122. self.num_timesteps = int(timesteps)
  123. self.linear_start = linear_start
  124. self.linear_end = linear_end
  125. assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
  126. to_torch = partial(torch.tensor, dtype=torch.float32)
  127. self.register_buffer('betas', to_torch(betas))
  128. self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
  129. self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
  130. # calculations for diffusion q(x_t | x_{t-1}) and others
  131. self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
  132. self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
  133. self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
  134. self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
  135. self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
  136. # calculations for posterior q(x_{t-1} | x_t, x_0)
  137. posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
  138. 1. - alphas_cumprod) + self.v_posterior * betas
  139. # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
  140. self.register_buffer('posterior_variance', to_torch(posterior_variance))
  141. # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
  142. self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
  143. self.register_buffer('posterior_mean_coef1', to_torch(
  144. betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
  145. self.register_buffer('posterior_mean_coef2', to_torch(
  146. (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
  147. if self.parameterization == "eps":
  148. lvlb_weights = self.betas ** 2 / (
  149. 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
  150. elif self.parameterization == "x0":
  151. lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
  152. else:
  153. raise NotImplementedError("mu not supported")
  154. # TODO how to choose this term
  155. lvlb_weights[0] = lvlb_weights[1]
  156. self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
  157. assert not torch.isnan(self.lvlb_weights).all()
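# Note (added for clarity): the buffers registered above parameterize the Gaussian
# posterior q(x_{t-1} | x_t, x_0) used during sampling,
#     mean = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t,
#     var  = posterior_variance (log-clipped at 1e-20 because it is 0 at t == 0),
# and `lvlb_weights` are the per-timestep weights applied to the simple loss when it
# is re-interpreted as a variational-bound term (see `p_losses`). Reading aid only;
# this annotation does not change behaviour.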
  158. @contextmanager
  159. def ema_scope(self, context=None):
  160. if self.use_ema:
  161. self.model_ema.store(self.model.parameters())
  162. self.model_ema.copy_to(self.model)
  163. if context is not None:
  164. print(f"{context}: Switched to EMA weights")
  165. try:
  166. yield None
  167. finally:
  168. if self.use_ema:
  169. self.model_ema.restore(self.model.parameters())
  170. if context is not None:
  171. print(f"{context}: Restored training weights")
  172. def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
  173. ignore_keys = ignore_keys or []
  174. sd = torch.load(path, map_location="cpu")
  175. if "state_dict" in list(sd.keys()):
  176. sd = sd["state_dict"]
  177. keys = list(sd.keys())
  178. # Our model adds additional channels to the first layer to condition on an input image.
  179. # For the first layer, copy existing channel weights and initialize new channel weights to zero.
  180. input_keys = [
  181. "model.diffusion_model.input_blocks.0.0.weight",
  182. "model_ema.diffusion_modelinput_blocks00weight",
  183. ]
  184. self_sd = self.state_dict()
  185. for input_key in input_keys:
  186. if input_key not in sd or input_key not in self_sd:
  187. continue
  188. input_weight = self_sd[input_key]
  189. if input_weight.size() != sd[input_key].size():
  190. print(f"Manual init: {input_key}")
  191. input_weight.zero_()
  192. input_weight[:, :4, :, :].copy_(sd[input_key])
  193. ignore_keys.append(input_key)
  194. for k in keys:
  195. for ik in ignore_keys:
  196. if k.startswith(ik):
  197. print(f"Deleting key {k} from state_dict.")
  198. del sd[k]
  199. missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
  200. sd, strict=False)
  201. print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
  202. if missing:
  203. print(f"Missing Keys: {missing}")
  204. if unexpected:
  205. print(f"Unexpected Keys: {unexpected}")
  206. def q_mean_variance(self, x_start, t):
  207. """
  208. Get the distribution q(x_t | x_0).
  209. :param x_start: the [N x C x ...] tensor of noiseless inputs.
  210. :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
  211. :return: A tuple (mean, variance, log_variance), all of x_start's shape.
  212. """
  213. mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
  214. variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
  215. log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
  216. return mean, variance, log_variance
  217. def predict_start_from_noise(self, x_t, t, noise):
  218. return (
  219. extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
  220. extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
  221. )
  222. def q_posterior(self, x_start, x_t, t):
  223. posterior_mean = (
  224. extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
  225. extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
  226. )
  227. posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
  228. posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
  229. return posterior_mean, posterior_variance, posterior_log_variance_clipped
  230. def p_mean_variance(self, x, t, clip_denoised: bool):
  231. model_out = self.model(x, t)
  232. if self.parameterization == "eps":
  233. x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
  234. elif self.parameterization == "x0":
  235. x_recon = model_out
  236. if clip_denoised:
  237. x_recon.clamp_(-1., 1.)
  238. model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
  239. return model_mean, posterior_variance, posterior_log_variance
  240. @torch.no_grad()
  241. def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
  242. b, *_, device = *x.shape, x.device
  243. model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
  244. noise = noise_like(x.shape, device, repeat_noise)
  245. # no noise when t == 0
  246. nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
  247. return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
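# Note (added for clarity): the return line above implements one ancestral sampling step,
#     x_{t-1} = mean_theta(x_t, t) + exp(0.5 * log_var) * eps,   eps ~ N(0, I),
# with `nonzero_mask` zeroing the noise at t == 0 so the final step returns the
# posterior mean directly. Reading aid only, not part of the original module.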
  248. @torch.no_grad()
  249. def p_sample_loop(self, shape, return_intermediates=False):
  250. device = self.betas.device
  251. b = shape[0]
  252. img = torch.randn(shape, device=device)
  253. intermediates = [img]
  254. for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
  255. img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
  256. clip_denoised=self.clip_denoised)
  257. if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
  258. intermediates.append(img)
  259. if return_intermediates:
  260. return img, intermediates
  261. return img
  262. @torch.no_grad()
  263. def sample(self, batch_size=16, return_intermediates=False):
  264. image_size = self.image_size
  265. channels = self.channels
  266. return self.p_sample_loop((batch_size, channels, image_size, image_size),
  267. return_intermediates=return_intermediates)
  268. def q_sample(self, x_start, t, noise=None):
  269. noise = default(noise, lambda: torch.randn_like(x_start))
  270. return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
  271. extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
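# Note (added for clarity): `q_sample` draws from the closed-form forward process
#     q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I),
# i.e. x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise.
# A minimal usage sketch (illustrative only; `model` and `x0` are assumed to exist):
#     t = torch.randint(0, model.num_timesteps, (x0.shape[0],), device=x0.device)
#     x_t = model.q_sample(x_start=x0, t=t)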
  272. def get_loss(self, pred, target, mean=True):
  273. if self.loss_type == 'l1':
  274. loss = (target - pred).abs()
  275. if mean:
  276. loss = loss.mean()
  277. elif self.loss_type == 'l2':
  278. if mean:
  279. loss = torch.nn.functional.mse_loss(target, pred)
  280. else:
  281. loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
  282. else:
  283. raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
  284. return loss
  285. def p_losses(self, x_start, t, noise=None):
  286. noise = default(noise, lambda: torch.randn_like(x_start))
  287. x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
  288. model_out = self.model(x_noisy, t)
  289. loss_dict = {}
  290. if self.parameterization == "eps":
  291. target = noise
  292. elif self.parameterization == "x0":
  293. target = x_start
  294. else:
  295. raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
  296. loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
  297. log_prefix = 'train' if self.training else 'val'
  298. loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
  299. loss_simple = loss.mean() * self.l_simple_weight
  300. loss_vlb = (self.lvlb_weights[t] * loss).mean()
  301. loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
  302. loss = loss_simple + self.original_elbo_weight * loss_vlb
  303. loss_dict.update({f'{log_prefix}/loss': loss})
  304. return loss, loss_dict
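# Note (added for clarity): the total objective assembled above is
#     loss = l_simple_weight * L_simple + original_elbo_weight * L_vlb,
# where L_simple is the per-example loss between the model output and the target
# (noise for "eps", x_0 for "x0") and L_vlb is the same loss reweighted by
# `lvlb_weights[t]`. With the default original_elbo_weight = 0 this reduces to the
# standard simple DDPM loss. Reading aid only.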
  305. def forward(self, x, *args, **kwargs):
  306. # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
  307. # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
  308. t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
  309. return self.p_losses(x, t, *args, **kwargs)
  310. def get_input(self, batch, k):
  311. return batch[k]
  312. def shared_step(self, batch):
  313. x = self.get_input(batch, self.first_stage_key)
  314. loss, loss_dict = self(x)
  315. return loss, loss_dict
  316. def training_step(self, batch, batch_idx):
  317. loss, loss_dict = self.shared_step(batch)
  318. self.log_dict(loss_dict, prog_bar=True,
  319. logger=True, on_step=True, on_epoch=True)
  320. self.log("global_step", self.global_step,
  321. prog_bar=True, logger=True, on_step=True, on_epoch=False)
  322. if self.use_scheduler:
  323. lr = self.optimizers().param_groups[0]['lr']
  324. self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
  325. return loss
  326. @torch.no_grad()
  327. def validation_step(self, batch, batch_idx):
  328. _, loss_dict_no_ema = self.shared_step(batch)
  329. with self.ema_scope():
  330. _, loss_dict_ema = self.shared_step(batch)
  331. loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
  332. self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
  333. self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
  334. def on_train_batch_end(self, *args, **kwargs):
  335. if self.use_ema:
  336. self.model_ema(self.model)
  337. def _get_rows_from_list(self, samples):
  338. n_imgs_per_row = len(samples)
  339. denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
  340. denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
  341. denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
  342. return denoise_grid
  343. @torch.no_grad()
  344. def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
  345. log = {}
  346. x = self.get_input(batch, self.first_stage_key)
  347. N = min(x.shape[0], N)
  348. n_row = min(x.shape[0], n_row)
  349. x = x.to(self.device)[:N]
  350. log["inputs"] = x
  351. # get diffusion row
  352. diffusion_row = []
  353. x_start = x[:n_row]
  354. for t in range(self.num_timesteps):
  355. if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
  356. t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
  357. t = t.to(self.device).long()
  358. noise = torch.randn_like(x_start)
  359. x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
  360. diffusion_row.append(x_noisy)
  361. log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
  362. if sample:
  363. # get denoise row
  364. with self.ema_scope("Plotting"):
  365. samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
  366. log["samples"] = samples
  367. log["denoise_row"] = self._get_rows_from_list(denoise_row)
  368. if return_keys:
  369. if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
  370. return log
  371. else:
  372. return {key: log[key] for key in return_keys}
  373. return log
  374. def configure_optimizers(self):
  375. lr = self.learning_rate
  376. params = list(self.model.parameters())
  377. if self.learn_logvar:
  378. params = params + [self.logvar]
  379. opt = torch.optim.AdamW(params, lr=lr)
  380. return opt
  381. class LatentDiffusion(DDPM):
  382. """main class"""
  383. def __init__(self,
  384. first_stage_config,
  385. cond_stage_config,
  386. num_timesteps_cond=None,
  387. cond_stage_key="image",
  388. cond_stage_trainable=False,
  389. concat_mode=True,
  390. cond_stage_forward=None,
  391. conditioning_key=None,
  392. scale_factor=1.0,
  393. scale_by_std=False,
  394. load_ema=True,
  395. *args, **kwargs):
  396. self.num_timesteps_cond = default(num_timesteps_cond, 1)
  397. self.scale_by_std = scale_by_std
  398. assert self.num_timesteps_cond <= kwargs['timesteps']
  399. # for backwards compatibility after implementation of DiffusionWrapper
  400. if conditioning_key is None:
  401. conditioning_key = 'concat' if concat_mode else 'crossattn'
  402. if cond_stage_config == '__is_unconditional__':
  403. conditioning_key = None
  404. ckpt_path = kwargs.pop("ckpt_path", None)
  405. ignore_keys = kwargs.pop("ignore_keys", [])
  406. super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
  407. self.concat_mode = concat_mode
  408. self.cond_stage_trainable = cond_stage_trainable
  409. self.cond_stage_key = cond_stage_key
  410. try:
  411. self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
  412. except Exception:
  413. self.num_downs = 0
  414. if not scale_by_std:
  415. self.scale_factor = scale_factor
  416. else:
  417. self.register_buffer('scale_factor', torch.tensor(scale_factor))
  418. self.instantiate_first_stage(first_stage_config)
  419. self.instantiate_cond_stage(cond_stage_config)
  420. self.cond_stage_forward = cond_stage_forward
  421. self.clip_denoised = False
  422. self.bbox_tokenizer = None
  423. self.restarted_from_ckpt = False
  424. if ckpt_path is not None:
  425. self.init_from_ckpt(ckpt_path, ignore_keys)
  426. self.restarted_from_ckpt = True
  427. if self.use_ema and not load_ema:
  428. self.model_ema = LitEma(self.model)
  429. print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
  430. def make_cond_schedule(self):
  431. self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
  432. ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
  433. self.cond_ids[:self.num_timesteps_cond] = ids
  434. @rank_zero_only
  435. @torch.no_grad()
  436. def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
  437. # only for very first batch
  438. if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
  439. assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
  440. # set rescale weight to 1./std of encodings
  441. print("### USING STD-RESCALING ###")
  442. x = super().get_input(batch, self.first_stage_key)
  443. x = x.to(self.device)
  444. encoder_posterior = self.encode_first_stage(x)
  445. z = self.get_first_stage_encoding(encoder_posterior).detach()
  446. del self.scale_factor
  447. self.register_buffer('scale_factor', 1. / z.flatten().std())
  448. print(f"setting self.scale_factor to {self.scale_factor}")
  449. print("### USING STD-RESCALING ###")
  450. def register_schedule(self,
  451. given_betas=None, beta_schedule="linear", timesteps=1000,
  452. linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
  453. super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
  454. self.shorten_cond_schedule = self.num_timesteps_cond > 1
  455. if self.shorten_cond_schedule:
  456. self.make_cond_schedule()
  457. def instantiate_first_stage(self, config):
  458. model = instantiate_from_config(config)
  459. self.first_stage_model = model.eval()
  460. self.first_stage_model.train = disabled_train
  461. for param in self.first_stage_model.parameters():
  462. param.requires_grad = False
  463. def instantiate_cond_stage(self, config):
  464. if not self.cond_stage_trainable:
  465. if config == "__is_first_stage__":
  466. print("Using first stage also as cond stage.")
  467. self.cond_stage_model = self.first_stage_model
  468. elif config == "__is_unconditional__":
  469. print(f"Training {self.__class__.__name__} as an unconditional model.")
  470. self.cond_stage_model = None
  471. # self.be_unconditional = True
  472. else:
  473. model = instantiate_from_config(config)
  474. self.cond_stage_model = model.eval()
  475. self.cond_stage_model.train = disabled_train
  476. for param in self.cond_stage_model.parameters():
  477. param.requires_grad = False
  478. else:
  479. assert config != '__is_first_stage__'
  480. assert config != '__is_unconditional__'
  481. model = instantiate_from_config(config)
  482. self.cond_stage_model = model
  483. def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
  484. denoise_row = []
  485. for zd in tqdm(samples, desc=desc):
  486. denoise_row.append(self.decode_first_stage(zd.to(self.device),
  487. force_not_quantize=force_no_decoder_quantization))
  488. n_imgs_per_row = len(denoise_row)
  489. denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
  490. denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
  491. denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
  492. denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
  493. return denoise_grid
  494. def get_first_stage_encoding(self, encoder_posterior):
  495. if isinstance(encoder_posterior, DiagonalGaussianDistribution):
  496. z = encoder_posterior.sample()
  497. elif isinstance(encoder_posterior, torch.Tensor):
  498. z = encoder_posterior
  499. else:
  500. raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
  501. return self.scale_factor * z
  502. def get_learned_conditioning(self, c):
  503. if self.cond_stage_forward is None:
  504. if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
  505. c = self.cond_stage_model.encode(c)
  506. if isinstance(c, DiagonalGaussianDistribution):
  507. c = c.mode()
  508. else:
  509. c = self.cond_stage_model(c)
  510. else:
  511. assert hasattr(self.cond_stage_model, self.cond_stage_forward)
  512. c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
  513. return c
  514. def meshgrid(self, h, w):
  515. y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
  516. x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
  517. arr = torch.cat([y, x], dim=-1)
  518. return arr
  519. def delta_border(self, h, w):
  520. """
  521. :param h: height
  522. :param w: width
  523. :return: normalized distance to the image border,
  524. with min distance = 0 at the border and max distance = 0.5 at the image center
  525. """
  526. lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
  527. arr = self.meshgrid(h, w) / lower_right_corner
  528. dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
  529. dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
  530. edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
  531. return edge_dist
  532. def get_weighting(self, h, w, Ly, Lx, device):
  533. weighting = self.delta_border(h, w)
  534. weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
  535. self.split_input_params["clip_max_weight"], )
  536. weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
  537. if self.split_input_params["tie_braker"]:
  538. L_weighting = self.delta_border(Ly, Lx)
  539. L_weighting = torch.clip(L_weighting,
  540. self.split_input_params["clip_min_tie_weight"],
  541. self.split_input_params["clip_max_tie_weight"])
  542. L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
  543. weighting = weighting * L_weighting
  544. return weighting
  545. def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
  546. """
  547. :param x: img of size (bs, c, h, w)
  548. :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
  549. """
  550. bs, nc, h, w = x.shape
  551. # number of crops in image
  552. Ly = (h - kernel_size[0]) // stride[0] + 1
  553. Lx = (w - kernel_size[1]) // stride[1] + 1
  554. if uf == 1 and df == 1:
  555. fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
  556. unfold = torch.nn.Unfold(**fold_params)
  557. fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
  558. weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
  559. normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
  560. weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
  561. elif uf > 1 and df == 1:
  562. fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
  563. unfold = torch.nn.Unfold(**fold_params)
  564. fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
  565. dilation=1, padding=0,
  566. stride=(stride[0] * uf, stride[1] * uf))
  567. fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
  568. weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
  569. normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
  570. weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
  571. elif df > 1 and uf == 1:
  572. fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
  573. unfold = torch.nn.Unfold(**fold_params)
  574. fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
  575. dilation=1, padding=0,
  576. stride=(stride[0] // df, stride[1] // df))
  577. fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
  578. weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
  579. normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
  580. weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
  581. else:
  582. raise NotImplementedError
  583. return fold, unfold, normalization, weighting
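# Note (added for clarity): `get_fold_unfold` supports the optional patch-based
# ("split_input_params") mode. The pattern used by the callers below is, roughly:
#     z = unfold(x)                                  # (b, c*kh*kw, L) crops
#     z = z.view(b, -1, kh, kw, L)                   # back to image-shaped crops
#     outs = [f(z[..., i]) for i in range(L)]        # run the model per crop
#     o = torch.stack(outs, -1) * weighting          # down-weight crop borders
#     x = fold(o.view(b, -1, L)) / normalization     # stitch and renormalize overlaps
# This sketch is illustrative only (f, b, c, kh, kw, L are placeholders).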
  584. @torch.no_grad()
  585. def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
  586. cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
  587. x = super().get_input(batch, k)
  588. if bs is not None:
  589. x = x[:bs]
  590. x = x.to(self.device)
  591. encoder_posterior = self.encode_first_stage(x)
  592. z = self.get_first_stage_encoding(encoder_posterior).detach()
  593. cond_key = cond_key or self.cond_stage_key
  594. xc = super().get_input(batch, cond_key)
  595. if bs is not None:
  596. xc["c_crossattn"] = xc["c_crossattn"][:bs]
  597. xc["c_concat"] = xc["c_concat"][:bs]
  598. cond = {}
  599. # To support classifier-free guidance, randomly drop only the text conditioning for 5% of examples, only the image conditioning for 5%, and both for 5%.
  600. random = torch.rand(x.size(0), device=x.device)
  601. prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
  602. input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
  603. null_prompt = self.get_learned_conditioning([""])
  604. cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
  605. cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
  606. out = [z, cond]
  607. if return_first_stage_outputs:
  608. xrec = self.decode_first_stage(z)
  609. out.extend([x, xrec])
  610. if return_original_cond:
  611. out.append(xc)
  612. return out
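# Note (added for clarity): in `get_input` above, classifier-free guidance dropout is
# implemented with a single uniform draw per example (uncond defaults to 0.05):
#     random <  2 * uncond            -> text conditioning replaced by the null prompt
#     uncond <= random < 3 * uncond   -> image conditioning zeroed out
# The two intervals overlap on [uncond, 2 * uncond), so ~5% of examples drop only the
# text, ~5% drop only the image, and ~5% drop both, as stated in the inline comment.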
  613. @torch.no_grad()
  614. def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
  615. if predict_cids:
  616. if z.dim() == 4:
  617. z = torch.argmax(z.exp(), dim=1).long()
  618. z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
  619. z = rearrange(z, 'b h w c -> b c h w').contiguous()
  620. z = 1. / self.scale_factor * z
  621. if hasattr(self, "split_input_params"):
  622. if self.split_input_params["patch_distributed_vq"]:
  623. ks = self.split_input_params["ks"] # eg. (128, 128)
  624. stride = self.split_input_params["stride"] # eg. (64, 64)
  625. uf = self.split_input_params["vqf"]
  626. bs, nc, h, w = z.shape
  627. if ks[0] > h or ks[1] > w:
  628. ks = (min(ks[0], h), min(ks[1], w))
  629. print("reducing Kernel")
  630. if stride[0] > h or stride[1] > w:
  631. stride = (min(stride[0], h), min(stride[1], w))
  632. print("reducing stride")
  633. fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
  634. z = unfold(z) # (bn, nc * prod(**ks), L)
  635. # 1. Reshape to img shape
  636. z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
  637. # 2. apply model loop over last dim
  638. if isinstance(self.first_stage_model, VQModelInterface):
  639. output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
  640. force_not_quantize=predict_cids or force_not_quantize)
  641. for i in range(z.shape[-1])]
  642. else:
  643. output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
  644. for i in range(z.shape[-1])]
  645. o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
  646. o = o * weighting
  647. # Reverse 1. reshape to img shape
  648. o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
  649. # stitch crops together
  650. decoded = fold(o)
  651. decoded = decoded / normalization # norm is shape (1, 1, h, w)
  652. return decoded
  653. else:
  654. if isinstance(self.first_stage_model, VQModelInterface):
  655. return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
  656. else:
  657. return self.first_stage_model.decode(z)
  658. else:
  659. if isinstance(self.first_stage_model, VQModelInterface):
  660. return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
  661. else:
  662. return self.first_stage_model.decode(z)
  663. # same as above but without decorator
  664. def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
  665. if predict_cids:
  666. if z.dim() == 4:
  667. z = torch.argmax(z.exp(), dim=1).long()
  668. z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
  669. z = rearrange(z, 'b h w c -> b c h w').contiguous()
  670. z = 1. / self.scale_factor * z
  671. if hasattr(self, "split_input_params"):
  672. if self.split_input_params["patch_distributed_vq"]:
  673. ks = self.split_input_params["ks"] # eg. (128, 128)
  674. stride = self.split_input_params["stride"] # eg. (64, 64)
  675. uf = self.split_input_params["vqf"]
  676. bs, nc, h, w = z.shape
  677. if ks[0] > h or ks[1] > w:
  678. ks = (min(ks[0], h), min(ks[1], w))
  679. print("reducing Kernel")
  680. if stride[0] > h or stride[1] > w:
  681. stride = (min(stride[0], h), min(stride[1], w))
  682. print("reducing stride")
  683. fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
  684. z = unfold(z) # (bn, nc * prod(**ks), L)
  685. # 1. Reshape to img shape
  686. z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
  687. # 2. apply model loop over last dim
  688. if isinstance(self.first_stage_model, VQModelInterface):
  689. output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
  690. force_not_quantize=predict_cids or force_not_quantize)
  691. for i in range(z.shape[-1])]
  692. else:
  693. output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
  694. for i in range(z.shape[-1])]
  695. o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
  696. o = o * weighting
  697. # Reverse 1. reshape to img shape
  698. o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
  699. # stitch crops together
  700. decoded = fold(o)
  701. decoded = decoded / normalization # norm is shape (1, 1, h, w)
  702. return decoded
  703. else:
  704. if isinstance(self.first_stage_model, VQModelInterface):
  705. return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
  706. else:
  707. return self.first_stage_model.decode(z)
  708. else:
  709. if isinstance(self.first_stage_model, VQModelInterface):
  710. return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
  711. else:
  712. return self.first_stage_model.decode(z)
  713. @torch.no_grad()
  714. def encode_first_stage(self, x):
  715. if hasattr(self, "split_input_params"):
  716. if self.split_input_params["patch_distributed_vq"]:
  717. ks = self.split_input_params["ks"] # eg. (128, 128)
  718. stride = self.split_input_params["stride"] # eg. (64, 64)
  719. df = self.split_input_params["vqf"]
  720. self.split_input_params['original_image_size'] = x.shape[-2:]
  721. bs, nc, h, w = x.shape
  722. if ks[0] > h or ks[1] > w:
  723. ks = (min(ks[0], h), min(ks[1], w))
  724. print("reducing Kernel")
  725. if stride[0] > h or stride[1] > w:
  726. stride = (min(stride[0], h), min(stride[1], w))
  727. print("reducing stride")
  728. fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
  729. z = unfold(x) # (bn, nc * prod(**ks), L)
  730. # Reshape to img shape
  731. z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
  732. output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
  733. for i in range(z.shape[-1])]
  734. o = torch.stack(output_list, axis=-1)
  735. o = o * weighting
  736. # Reverse reshape to img shape
  737. o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
  738. # stitch crops together
  739. decoded = fold(o)
  740. decoded = decoded / normalization
  741. return decoded
  742. else:
  743. return self.first_stage_model.encode(x)
  744. else:
  745. return self.first_stage_model.encode(x)
  746. def shared_step(self, batch, **kwargs):
  747. x, c = self.get_input(batch, self.first_stage_key)
  748. loss = self(x, c)
  749. return loss
  750. def forward(self, x, c, *args, **kwargs):
  751. t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
  752. if self.model.conditioning_key is not None:
  753. assert c is not None
  754. if self.cond_stage_trainable:
  755. c = self.get_learned_conditioning(c)
  756. if self.shorten_cond_schedule: # TODO: drop this option
  757. tc = self.cond_ids[t].to(self.device)
  758. c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
  759. return self.p_losses(x, c, t, *args, **kwargs)
  760. def apply_model(self, x_noisy, t, cond, return_ids=False):
  761. if isinstance(cond, dict):
  762. # hybrid case, cond is expected to be a dict
  763. pass
  764. else:
  765. if not isinstance(cond, list):
  766. cond = [cond]
  767. key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
  768. cond = {key: cond}
  769. if hasattr(self, "split_input_params"):
  770. assert len(cond) == 1 # todo can only deal with one conditioning atm
  771. assert not return_ids
  772. ks = self.split_input_params["ks"] # eg. (128, 128)
  773. stride = self.split_input_params["stride"] # eg. (64, 64)
  774. h, w = x_noisy.shape[-2:]
  775. fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
  776. z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
  777. # Reshape to img shape
  778. z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
  779. z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
  780. if self.cond_stage_key in ["image", "LR_image", "segmentation",
  781. 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
  782. c_key = next(iter(cond.keys())) # get key
  783. c = next(iter(cond.values())) # get value
  784. assert (len(c) == 1) # todo extend to list with more than one elem
  785. c = c[0] # get element
  786. c = unfold(c)
  787. c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
  788. cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
  789. elif self.cond_stage_key == 'coordinates_bbox':
  790. assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
  791. # assuming padding of unfold is always 0 and its dilation is always 1
  792. n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
  793. full_img_h, full_img_w = self.split_input_params['original_image_size']
  794. # as we are operating on latents, we need the factor from the original image size to the
  795. # spatial latent size to properly rescale the crops for regenerating the bbox annotations
  796. num_downs = self.first_stage_model.encoder.num_resolutions - 1
  797. rescale_latent = 2 ** (num_downs)
  798. # get top-left positions of the patches in the form expected by the bbox tokenizer; we
  799. # need to rescale the top-left patch coordinates to lie in (0, 1)
  800. tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
  801. rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
  802. for patch_nr in range(z.shape[-1])]
  803. # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
  804. patch_limits = [(x_tl, y_tl,
  805. rescale_latent * ks[0] / full_img_w,
  806. rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
  807. # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
  808. # tokenize crop coordinates for the bounding boxes of the respective patches
  809. patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
  810. for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
  811. print(patch_limits_tknzd[0].shape)
  812. # cut tknzd crop position from conditioning
  813. assert isinstance(cond, dict), 'cond must be dict to be fed into model'
  814. cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
  815. print(cut_cond.shape)
  816. adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
  817. adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
  818. print(adapted_cond.shape)
  819. adapted_cond = self.get_learned_conditioning(adapted_cond)
  820. print(adapted_cond.shape)
  821. adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
  822. print(adapted_cond.shape)
  823. cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
  824. else:
  825. cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
  826. # apply model by loop over crops
  827. output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
  828. assert not isinstance(output_list[0],
  829. tuple) # todo cant deal with multiple model outputs check this never happens
  830. o = torch.stack(output_list, axis=-1)
  831. o = o * weighting
  832. # Reverse reshape to img shape
  833. o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
  834. # stitch crops together
  835. x_recon = fold(o) / normalization
  836. else:
  837. x_recon = self.model(x_noisy, t, **cond)
  838. if isinstance(x_recon, tuple) and not return_ids:
  839. return x_recon[0]
  840. else:
  841. return x_recon
  842. def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
  843. return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
  844. extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
  845. def _prior_bpd(self, x_start):
  846. """
  847. Get the prior KL term for the variational lower-bound, measured in
  848. bits-per-dim.
  849. This term can't be optimized, as it only depends on the encoder.
  850. :param x_start: the [N x C x ...] tensor of inputs.
  851. :return: a batch of [N] KL values (in bits), one per batch element.
  852. """
  853. batch_size = x_start.shape[0]
  854. t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
  855. qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
  856. kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
  857. return mean_flat(kl_prior) / np.log(2.0)
  858. def p_losses(self, x_start, cond, t, noise=None):
  859. noise = default(noise, lambda: torch.randn_like(x_start))
  860. x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
  861. model_output = self.apply_model(x_noisy, t, cond)
  862. loss_dict = {}
  863. prefix = 'train' if self.training else 'val'
  864. if self.parameterization == "x0":
  865. target = x_start
  866. elif self.parameterization == "eps":
  867. target = noise
  868. else:
  869. raise NotImplementedError()
  870. loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
  871. loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
  872. logvar_t = self.logvar[t].to(self.device)
  873. loss = loss_simple / torch.exp(logvar_t) + logvar_t
  874. # loss = loss_simple / torch.exp(self.logvar) + self.logvar
  875. if self.learn_logvar:
  876. loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
  877. loss_dict.update({'logvar': self.logvar.data.mean()})
  878. loss = self.l_simple_weight * loss.mean()
  879. loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
  880. loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
  881. loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
  882. loss += (self.original_elbo_weight * loss_vlb)
  883. loss_dict.update({f'{prefix}/loss': loss})
  884. return loss, loss_dict
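# Note (added for clarity): `p_losses` above weights the simple loss with an
# (optionally learned) per-timestep log-variance,
#     loss = L_simple / exp(logvar_t) + logvar_t,
# a heteroscedastic-style weighting. With learn_logvar=False and logvar_init=0 the
# term is inert and the loss reduces to L_simple (plus the optional VLB term).
# Reading aid only.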
  885. def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
  886. return_x0=False, score_corrector=None, corrector_kwargs=None):
  887. t_in = t
  888. model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
  889. if score_corrector is not None:
  890. assert self.parameterization == "eps"
  891. model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
  892. if return_codebook_ids:
  893. model_out, logits = model_out
  894. if self.parameterization == "eps":
  895. x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
  896. elif self.parameterization == "x0":
  897. x_recon = model_out
  898. else:
  899. raise NotImplementedError()
  900. if clip_denoised:
  901. x_recon.clamp_(-1., 1.)
  902. if quantize_denoised:
  903. x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
  904. model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
  905. if return_codebook_ids:
  906. return model_mean, posterior_variance, posterior_log_variance, logits
  907. elif return_x0:
  908. return model_mean, posterior_variance, posterior_log_variance, x_recon
  909. else:
  910. return model_mean, posterior_variance, posterior_log_variance
  911. @torch.no_grad()
  912. def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
  913. return_codebook_ids=False, quantize_denoised=False, return_x0=False,
  914. temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
  915. b, *_, device = *x.shape, x.device
  916. outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
  917. return_codebook_ids=return_codebook_ids,
  918. quantize_denoised=quantize_denoised,
  919. return_x0=return_x0,
  920. score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
  921. if return_codebook_ids:
  922. raise DeprecationWarning("Support dropped.")
  923. model_mean, _, model_log_variance, logits = outputs
  924. elif return_x0:
  925. model_mean, _, model_log_variance, x0 = outputs
  926. else:
  927. model_mean, _, model_log_variance = outputs
  928. noise = noise_like(x.shape, device, repeat_noise) * temperature
  929. if noise_dropout > 0.:
  930. noise = torch.nn.functional.dropout(noise, p=noise_dropout)
  931. # no noise when t == 0
  932. nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
  933. if return_codebook_ids:
  934. return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
  935. if return_x0:
  936. return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
  937. else:
  938. return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
  939. @torch.no_grad()
  940. def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
  941. img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
  942. score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
  943. log_every_t=None):
  944. if not log_every_t:
  945. log_every_t = self.log_every_t
  946. timesteps = self.num_timesteps
  947. if batch_size is not None:
  948. b = batch_size if batch_size is not None else shape[0]
  949. shape = [batch_size] + list(shape)
  950. else:
  951. b = batch_size = shape[0]
  952. if x_T is None:
  953. img = torch.randn(shape, device=self.device)
  954. else:
  955. img = x_T
  956. intermediates = []
  957. if cond is not None:
  958. if isinstance(cond, dict):
  959. cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
  960. [x[:batch_size] for x in cond[key]] for key in cond}
  961. else:
  962. cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
  963. if start_T is not None:
  964. timesteps = min(timesteps, start_T)
  965. iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
  966. total=timesteps) if verbose else reversed(
  967. range(0, timesteps))
  968. if type(temperature) == float:
  969. temperature = [temperature] * timesteps
  970. for i in iterator:
  971. ts = torch.full((b,), i, device=self.device, dtype=torch.long)
  972. if self.shorten_cond_schedule:
  973. assert self.model.conditioning_key != 'hybrid'
  974. tc = self.cond_ids[ts].to(cond.device)
  975. cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
  976. img, x0_partial = self.p_sample(img, cond, ts,
  977. clip_denoised=self.clip_denoised,
  978. quantize_denoised=quantize_denoised, return_x0=True,
  979. temperature=temperature[i], noise_dropout=noise_dropout,
  980. score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
  981. if mask is not None:
  982. assert x0 is not None
  983. img_orig = self.q_sample(x0, ts)
  984. img = img_orig * mask + (1. - mask) * img
  985. if i % log_every_t == 0 or i == timesteps - 1:
  986. intermediates.append(x0_partial)
  987. if callback:
  988. callback(i)
  989. if img_callback:
  990. img_callback(img, i)
  991. return img, intermediates
  992. @torch.no_grad()
  993. def p_sample_loop(self, cond, shape, return_intermediates=False,
  994. x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
  995. mask=None, x0=None, img_callback=None, start_T=None,
  996. log_every_t=None):
  997. if not log_every_t:
  998. log_every_t = self.log_every_t
  999. device = self.betas.device
  1000. b = shape[0]
  1001. if x_T is None:
  1002. img = torch.randn(shape, device=device)
  1003. else:
  1004. img = x_T
  1005. intermediates = [img]
  1006. if timesteps is None:
  1007. timesteps = self.num_timesteps
  1008. if start_T is not None:
  1009. timesteps = min(timesteps, start_T)
  1010. iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
  1011. range(0, timesteps))
  1012. if mask is not None:
  1013. assert x0 is not None
  1014. assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
  1015. for i in iterator:
  1016. ts = torch.full((b,), i, device=device, dtype=torch.long)
  1017. if self.shorten_cond_schedule:
  1018. assert self.model.conditioning_key != 'hybrid'
  1019. tc = self.cond_ids[ts].to(cond.device)
  1020. cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
  1021. img = self.p_sample(img, cond, ts,
  1022. clip_denoised=self.clip_denoised,
  1023. quantize_denoised=quantize_denoised)
  1024. if mask is not None:
  1025. img_orig = self.q_sample(x0, ts)
  1026. img = img_orig * mask + (1. - mask) * img
  1027. if i % log_every_t == 0 or i == timesteps - 1:
  1028. intermediates.append(img)
  1029. if callback:
  1030. callback(i)
  1031. if img_callback:
  1032. img_callback(img, i)
  1033. if return_intermediates:
  1034. return img, intermediates
  1035. return img
  1036. @torch.no_grad()
  1037. def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
  1038. verbose=True, timesteps=None, quantize_denoised=False,
  1039. mask=None, x0=None, shape=None, **kwargs):
  1040. if shape is None:
  1041. shape = (batch_size, self.channels, self.image_size, self.image_size)
  1042. if cond is not None:
  1043. if isinstance(cond, dict):
  1044. cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
  1045. [x[:batch_size] for x in cond[key]] for key in cond}
  1046. else:
  1047. cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
  1048. return self.p_sample_loop(cond,
  1049. shape,
  1050. return_intermediates=return_intermediates, x_T=x_T,
  1051. verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
  1052. mask=mask, x0=x0)
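    # Hypothetical usage of sample() (a sketch, not taken from this file): given a
    # conditioning dict `c` as produced by get_input on an instance `model` of this
    # class, draw a small batch of latents and decode them with the first stage:
    #
    #     z_samples = model.sample(cond=c, batch_size=4, return_intermediates=False)
    #     x_samples = model.decode_first_stage(z_samples)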
    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                         shape, cond, verbose=False, **kwargs)
        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True, **kwargs)
        return samples, intermediates
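    # log_images below assembles a dict of image tensors for the logger. Depending on
    # the flags it can contain: "inputs", "reals", "reconstruction", "conditioning",
    # "original_conditioning", "diffusion_row" (forward noising of the inputs),
    # "samples" / "denoise_row", "samples_x0_quantized",
    # "samples_inpainting" / "mask" / "samples_outpainting", and "progressive_row".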
    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
                   plot_diffusion_rows=False, **kwargs):

        use_ddim = False

        log = {}
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                           return_first_stage_outputs=True,
                                           force_c_encode=True,
                                           return_original_cond=True,
                                           bs=N, uncond=0)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reals"] = xc["c_concat"]
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = []
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with self.ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             quantize_denoised=True)
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

            if inpaint:
                # make a simple center square
                h, w = z.shape[2], z.shape[3]
                mask = torch.ones(N, h, w).to(self.device)
                # zeros will be filled in
                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
                mask = mask[:, None, ...]
                with self.ema_scope("Plotting Inpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_inpainting"] = x_samples
                log["mask"] = mask

                # outpaint
                with self.ema_scope("Plotting Outpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with self.ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log
    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt
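    # configure_optimizers above expects scheduler_config to be an
    # instantiate_from_config-style dict whose instantiated object exposes a
    # schedule(step) callable, used as the LambdaLR multiplier. A hypothetical
    # example (the target class and parameter names are assumptions, not taken
    # from this file):
    #
    #     scheduler_config:
    #       target: ldm.lr_scheduler.LambdaLinearScheduler   # assumed helper with a .schedule method
    #       params:
    #         warm_up_steps: [1000]
    #         f_start: [1.e-6]
    #         f_max: [1.]
    #         f_min: [1.]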
    @torch.no_grad()
    def to_rgb(self, x):
        # project an arbitrary number of channels to 3 with a fixed random 1x1 conv,
        # then rescale the result to [-1, 1] for visualization
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
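# Conditioning modes dispatched by DiffusionWrapper.forward above:
#   None        -> unconditional: diffusion_model(x, t)
#   'concat'    -> c_concat tensors are concatenated with x along the channel dim
#   'crossattn' -> c_crossattn tensors are concatenated and passed as `context`
#   'hybrid'    -> both channel concatenation and cross-attention context
#   'adm'       -> the first c_crossattn entry is passed as class label `y`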

class Layout2ImgDiffusion(LatentDiffusion):
    # TODO: move all layout-specific hacks to this class
    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        logs = super().log_images(*args, batch=batch, N=N, **kwargs)

        key = 'train' if self.training else 'validation'
        dset = self.trainer.datamodule.datasets[key]
        mapper = dset.conditional_builders[self.cond_stage_key]

        bbox_imgs = []
        map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
        for tknzd_bbox in batch[self.cond_stage_key][:N]:
            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
            bbox_imgs.append(bboximg)

        cond_img = torch.stack(bbox_imgs, dim=0)
        logs['bbox_image'] = cond_img
        return logs