|
@@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
|
|
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
|
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
|
from ldm.modules.ema import LitEma
|
|
from ldm.modules.ema import LitEma
|
|
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
|
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
|
-from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
|
|
|
|
|
|
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
|
|
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
|
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
|
|
|
|
|
|
+try:
|
|
|
|
+ from ldm.models.autoencoder import VQModelInterface
|
|
|
|
+except Exception:
|
|
|
|
+ class VQModelInterface:
|
|
|
|
+ pass
|
|
|
|
|
|
__conditioning_keys__ = {'concat': 'c_concat',
|
|
__conditioning_keys__ = {'concat': 'c_concat',
|
|
'crossattn': 'c_crossattn',
|
|
'crossattn': 'c_crossattn',
|