@@ -9,7 +9,7 @@ import k_diffusion.sampling
 import torchsde._brownian.brownian_interval
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing, images
+from modules import prompt_parser, devices, processing, images, sd_vae_approx
 
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -106,28 +106,31 @@ def setup_img2img_steps(p, steps=None):
     return steps, t_enc
 
 
-def single_sample_to_image(sample, approximation=False):
-    if approximation:
-        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
-        coefs = torch.tensor(
-            [[ 0.298, 0.207, 0.208],
-             [ 0.187, 0.286, 0.173],
-             [-0.158, 0.189, 0.264],
-             [-0.184, -0.271, -0.473]]).to(sample.device)
-        x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+    if approximation is None:
+        approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+    if approximation == 2:
+        x_sample = sd_vae_approx.cheap_approximation(sample)
+    elif approximation == 1:
+        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
     else:
         x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
     x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = x_sample.astype(np.uint8)
     return Image.fromarray(x_sample)
 
 
-def sample_to_image(samples, index=0, approximation=False):
+def sample_to_image(samples, index=0, approximation=None):
     return single_sample_to_image(samples[index], approximation)
 
 
-def samples_to_image_grid(samples, approximation=False):
+def samples_to_image_grid(samples, approximation=None):
     return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
 
 
@@ -136,7 +139,7 @@ def store_latent(decoded):
 
     if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
         if not shared.parallel_processing_allowed:
-            shared.state.current_image = sample_to_image(decoded, approximation=opts.show_progress_approximate)
+            shared.state.current_image = sample_to_image(decoded)
 
 
 class InterruptedException(BaseException):
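
Note: sd_vae_approx is a new module not shown in this diff; only cheap_approximation and model() are referenced above. As a rough, non-authoritative sketch, the "Approx cheap" path presumably does what the removed inline code did: project the 4-channel latent straight to RGB with a fixed 4x3 matrix instead of running the full VAE decoder. Assuming it mirrors that removed code:

    import torch

    def cheap_approximation(sample):
        # sample: latent of shape (4, H, W); project each latent channel
        # onto RGB with fixed coefficients instead of a full VAE decode.
        # Coefficients taken from the inline code removed above, via
        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
        coefs = torch.tensor(
            [[ 0.298, 0.207, 0.208],
             [ 0.187, 0.286, 0.173],
             [-0.158, 0.189, 0.264],
             [-0.184, -0.271, -0.473]]).to(sample.device)

        # "lxy,lr -> rxy": contract over the latent channel l, yielding an
        # RGB image of shape (3, H, W) in roughly [-1, 1], which the caller
        # then rescales with torch.clamp((x + 1.0) / 2.0, 0.0, 1.0).
        return torch.einsum("lxy,lr -> rxy", sample, coefs)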