|
@@ -2,6 +2,7 @@ import dataclasses
|
|
|
import torch
|
|
|
import k_diffusion
|
|
|
import numpy as np
|
|
|
+from scipy import stats
|
|
|
|
|
|
from modules import shared
|
|
|
|
|
@@ -115,6 +116,17 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
|
|
return torch.FloatTensor(sigs).to(device)
|
|
|
|
|
|
|
|
|
+def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
|
|
+ # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
|
|
|
+ alpha = shared.opts.beta_dist_alpha
|
|
|
+ beta = shared.opts.beta_dist_beta
|
|
|
+ timesteps = 1 - np.linspace(0, 1, n)
|
|
|
+ timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
|
|
|
+ sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
|
|
|
+ sigmas += [0.0]
|
|
|
+ return torch.FloatTensor(sigmas).to(device)
|
|
|
+
|
|
|
+
|
|
|
schedulers = [
|
|
|
Scheduler('automatic', 'Automatic', None),
|
|
|
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
|
@@ -127,6 +139,7 @@ schedulers = [
|
|
|
Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
|
|
|
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
|
|
|
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
|
|
|
+ Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),
|
|
|
]
|
|
|
|
|
|
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|