# sd_samplers_custom_schedulers.py

import torch
  2. def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
  3. start = inner_model.sigma_to_t(torch.tensor(sigma_max))
  4. end = inner_model.sigma_to_t(torch.tensor(sigma_min))
  5. sigs = [
  6. inner_model.t_to_sigma(ts)
  7. for ts in torch.linspace(start, end, n)[:-1]
  8. ]
  9. sigs += [0.0]
  10. return torch.FloatTensor(sigs).to(device)