sd_samplers_lcm.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import torch
  2. from k_diffusion import utils, sampling
  3. from k_diffusion.external import DiscreteEpsDDPMDenoiser
  4. from k_diffusion.sampling import default_noise_sampler, trange
  5. from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common
  6. class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
  7. def __init__(self, model):
  8. timesteps = 1000
  9. original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
  10. self.skip_steps = timesteps // original_timesteps
  11. alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
  12. for x in range(original_timesteps):
  13. alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
  14. super().__init__(model, alphas_cumprod_valid, quantize=None)
  15. def get_sigmas(self, n=None,):
  16. if n is None:
  17. return sampling.append_zero(self.sigmas.flip(0))
  18. start = self.sigma_to_t(self.sigma_max)
  19. end = self.sigma_to_t(self.sigma_min)
  20. t = torch.linspace(start, end, n, device=shared.sd_model.device)
  21. return sampling.append_zero(self.t_to_sigma(t))
  22. def sigma_to_t(self, sigma, quantize=None):
  23. log_sigma = sigma.log()
  24. dists = log_sigma - self.log_sigmas[:, None]
  25. return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
  26. def t_to_sigma(self, timestep):
  27. t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
  28. return super().t_to_sigma(t)
  29. def get_eps(self, *args, **kwargs):
  30. return self.inner_model.apply_model(*args, **kwargs)
  31. def get_scaled_out(self, sigma, output, input):
  32. sigma_data = 0.5
  33. scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
  34. c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
  35. c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
  36. return c_out * output + c_skip * input
  37. def forward(self, input, sigma, **kwargs):
  38. c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
  39. eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
  40. return self.get_scaled_out(sigma, input + eps * c_out, input)
  41. def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
  42. extra_args = {} if extra_args is None else extra_args
  43. noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
  44. s_in = x.new_ones([x.shape[0]])
  45. for i in trange(len(sigmas) - 1, disable=disable):
  46. denoised = model(x, sigmas[i] * s_in, **extra_args)
  47. if callback is not None:
  48. callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
  49. x = denoised
  50. if sigmas[i + 1] > 0:
  51. x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
  52. return x
  53. class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
  54. @property
  55. def inner_model(self):
  56. if self.model_wrap is None:
  57. denoiser = LCMCompVisDenoiser
  58. self.model_wrap = denoiser(shared.sd_model)
  59. return self.model_wrap
  60. class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
  61. def __init__(self, funcname, sd_model, options=None):
  62. super().__init__(funcname, sd_model, options)
  63. self.model_wrap_cfg = CFGDenoiserLCM(self)
  64. self.model_wrap = self.model_wrap_cfg.inner_model
  65. samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]
  66. samplers_data_lcm = [
  67. sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)
  68. for label, funcname, aliases, options in samplers_lcm
  69. ]