|
from comfy.k_diffusion import sampling as k_diffusion_sampling |
|
import torch |
|
|
|
def get_sigmas_karras(model, n, s_min, s_max, device):
    """Return a Karras-style sigma schedule by delegating to k-diffusion.

    ``model`` is not consulted here; it is accepted so this function shares the
    ``(model, n, s_min, s_max, device)`` signature used by every scheduler
    registered in ``SCHEDULER_MAPPING``.
    """
    build_schedule = k_diffusion_sampling.get_sigmas_karras
    schedule = build_schedule(n, s_min, s_max, device=device)
    return schedule
|
|
|
|
|
def get_sigmas_exponential(model, n, s_min, s_max, device):
    """Return an exponential sigma schedule by delegating to k-diffusion.

    ``model`` is unused; it is accepted only so this function matches the
    common scheduler signature ``(model, n, s_min, s_max, device)``.
    """
    build_schedule = k_diffusion_sampling.get_sigmas_exponential
    schedule = build_schedule(n, s_min, s_max, device=device)
    return schedule
|
|
|
|
|
def get_sigmas_normal(model, n, s_min, s_max, device):
    """Build a schedule that is linear in timestep space between s_max and s_min.

    The two sigma bounds are converted to timesteps with ``model.sigma_to_t``.
    ``n`` timesteps are then spaced evenly from the high bound down to the low
    one, mapped back to sigmas with ``model.t_to_sigma``, and passed through
    k-diffusion's ``append_zero`` (presumably appending a terminal 0 sigma;
    confirm against k_diffusion).
    """
    bounds = torch.tensor([s_min, s_max], device=device)
    t_min, t_max = model.sigma_to_t(bounds)
    # Descending: the first timestep corresponds to s_max.
    timesteps = torch.linspace(t_max, t_min, n, device=device)
    sigmas = model.t_to_sigma(timesteps)
    return k_diffusion_sampling.append_zero(sigmas)
|
|
|
|
|
def get_sigmas_simple(model, n, s_min, s_max, device):
    """Build an ``n``-step schedule by sampling entries of ``model.sigmas``.

    ``s_min`` and ``s_max`` are snapped to their nearest entries in
    ``model.sigmas`` to choose the slice that intermediate values are drawn
    from. The returned schedule starts at ``s_max`` and ends with ``s_min``
    followed by ``0.0``. The endpoints are the values passed in, not the
    snapped table entries.

    Returns a 1-D tensor on ``device`` with ``max(n, 0) + 1`` entries. For
    ``n <= 0`` this is just ``[0.0]``; for ``n == 1`` it is ``[s_max, 0.0]``.
    """
    if n <= 0:
        # Nothing to sample; also avoids the division by n below.
        return torch.tensor([0.0], device=device)
    if n == 1:
        # A single step has no room for both endpoints. Start at s_max so the
        # result has n + 1 entries like the other schedulers.
        return torch.tensor([float(s_max), 0.0], device=device)

    min_idx = torch.argmin(torch.abs(model.sigmas - s_min))
    max_idx = torch.argmin(torch.abs(model.sigmas - s_max))
    # NOTE(review): if min_idx >= max_idx this slice is empty and n >= 3 will
    # raise IndexError below; confirm whether callers guarantee s_min < s_max.
    sigmas_slice = model.sigmas[min_idx:max_idx]
    stride = len(sigmas_slice) / n

    # n - 2 interior points, walking down from the top of the slice.
    # NOTE(review): only n - 2 strides are taken, so the gap before s_min is
    # roughly twice the others; confirm this spacing is intended.
    middle = [float(sigmas_slice[-(1 + int(x * stride))]) for x in range(1, n - 1)]
    sigs = [float(s_max), *middle, float(s_min), 0.0]
    return torch.tensor(sigs, device=device)
|
|
|
|
|
def get_sigmas_ddim_uniform(model, n, s_min, s_max, device):
    """Build a schedule from ``n`` evenly spaced integer timesteps.

    ``s_min``/``s_max`` are converted to timesteps with ``model.sigma_to_t``.
    ``n`` int16 timesteps are laid out from the high bound down to the low one
    and capped at 999. Each one is mapped back with ``model.t_to_sigma``, and
    a final ``0.0`` is appended.

    Returns a 1-D tensor on ``device`` with ``n + 1`` entries.
    """
    t_min, t_max = model.sigma_to_t(torch.tensor([s_min, s_max], device=device))
    ddim_timesteps = torch.linspace(t_max, t_min, n, dtype=torch.int16, device=device)
    # Cap as a tensor op. The previous element-wise `ts = 999` rebound ts to a
    # Python int, so t_to_sigma received a non-tensor and any tensor-method
    # implementation (e.g. `t.float()`) would fail.
    # NOTE(review): 999 presumably assumes a 1000-entry discrete schedule; confirm.
    ddim_timesteps = ddim_timesteps.clamp(max=999)
    sigs = [float(model.t_to_sigma(ts)) for ts in ddim_timesteps]
    sigs.append(0.0)
    return torch.tensor(sigs, device=device)
|
|
|
|
|
def get_sigmas_simple_test(model, n, s_min, s_max, device):
    """Sample ``n`` sigmas at a fixed stride from a slice of ``model.sigmas``.

    ``s_min``/``s_max`` are snapped to their nearest entries in
    ``model.sigmas``. Values are taken from the top of
    ``model.sigmas[min_idx:max_idx]`` downward at a stride of ``len(slice) / n``,
    and ``0.0`` is appended.

    Returns a 1-D tensor on ``device`` with ``max(n, 0) + 1`` entries.
    """
    if n <= 0:
        # Nothing to sample; also avoids the division by n below.
        return torch.tensor([0.0], device=device)

    min_idx = torch.argmin(torch.abs(model.sigmas - s_min))
    max_idx = torch.argmin(torch.abs(model.sigmas - s_max))
    # NOTE(review): the slice excludes max_idx, so the first sigma is the entry
    # just below s_max rather than s_max itself; confirm this is intended. An
    # empty slice (min_idx >= max_idx) raises IndexError below.
    sigmas_slice = model.sigmas[min_idx:max_idx]
    stride = len(sigmas_slice) / n
    sigs = [float(sigmas_slice[-(1 + int(x * stride))]) for x in range(n)]
    sigs.append(0.0)
    return torch.tensor(sigs, device=device)
|
|
|
|
|
# Registry from scheduler name to its sigma-schedule builder. Every entry is
# called as fn(model, n, s_min, s_max, device).
SCHEDULER_MAPPING = dict(
    normal=get_sigmas_normal,
    karras=get_sigmas_karras,
    exponential=get_sigmas_exponential,
    simple=get_sigmas_simple,
    ddim_uniform=get_sigmas_ddim_uniform,
    simple_test=get_sigmas_simple_test,
)
|
|