|
@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
|
|
|
|
|
|
@torch.no_grad()
|
|
@torch.no_grad()
|
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
|
- log = dict()
|
|
|
|
|
|
+ log = {}
|
|
x = self.get_input(batch, self.first_stage_key)
|
|
x = self.get_input(batch, self.first_stage_key)
|
|
N = min(x.shape[0], N)
|
|
N = min(x.shape[0], N)
|
|
n_row = min(x.shape[0], n_row)
|
|
n_row = min(x.shape[0], n_row)
|
|
@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
|
|
log["inputs"] = x
|
|
log["inputs"] = x
|
|
|
|
|
|
# get diffusion row
|
|
# get diffusion row
|
|
- diffusion_row = list()
|
|
|
|
|
|
+ diffusion_row = []
|
|
x_start = x[:n_row]
|
|
x_start = x[:n_row]
|
|
|
|
|
|
for t in range(self.num_timesteps):
|
|
for t in range(self.num_timesteps):
|
|
@@ -1247,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
|
|
|
|
|
|
use_ddim = ddim_steps is not None
|
|
use_ddim = ddim_steps is not None
|
|
|
|
|
|
- log = dict()
|
|
|
|
|
|
+ log = {}
|
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
|
return_first_stage_outputs=True,
|
|
return_first_stage_outputs=True,
|
|
force_c_encode=True,
|
|
force_c_encode=True,
|
|
@@ -1274,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
|
|
|
|
|
|
if plot_diffusion_rows:
|
|
if plot_diffusion_rows:
|
|
# get diffusion row
|
|
# get diffusion row
|
|
- diffusion_row = list()
|
|
|
|
|
|
+ diffusion_row = []
|
|
z_start = z[:n_row]
|
|
z_start = z[:n_row]
|
|
for t in range(self.num_timesteps):
|
|
for t in range(self.num_timesteps):
|
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|