File size: 2,167 Bytes
6fecfbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from comfy.k_diffusion import sampling as k_diffusion_sampling
import torch
def get_sigmas_karras(model, n, s_min, s_max, device):
    """Return the k_diffusion Karras sigma schedule for ``n`` steps.

    Delegates to ``k_diffusion_sampling.get_sigmas_karras`` with the sigma
    range ``[s_min, s_max]`` and places the result on ``device``.
    ``model`` is not used here; presumably it is accepted so that every entry
    of the scheduler table shares the same call signature.
    """
    schedule = k_diffusion_sampling.get_sigmas_karras(n, s_min, s_max, device=device)
    return schedule
def get_sigmas_exponential(model, n, s_min, s_max, device):
    """Return the k_diffusion exponential sigma schedule for ``n`` steps.

    Delegates to ``k_diffusion_sampling.get_sigmas_exponential`` with the sigma
    range ``[s_min, s_max]`` and places the result on ``device``.
    ``model`` is not used here; presumably it is accepted so that every entry
    of the scheduler table shares the same call signature.
    """
    schedule = k_diffusion_sampling.get_sigmas_exponential(n, s_min, s_max, device=device)
    return schedule
def get_sigmas_normal(model, n, s_min, s_max, device):
    """Space ``n`` points evenly in timestep space between ``s_max`` and ``s_min``.

    The sigma bounds are mapped to timesteps with ``model.sigma_to_t``. ``n``
    evenly spaced timesteps run from the high bound down to the low one. They
    are mapped back with ``model.t_to_sigma``, and a trailing zero is appended
    via ``k_diffusion_sampling.append_zero``.
    """
    bounds = torch.tensor([s_min, s_max], device=device)
    t_low, t_high = model.sigma_to_t(bounds)
    # Descending timesteps, i.e. from high noise to low noise.
    timesteps = torch.linspace(t_high, t_low, n, device=device)
    sigmas = model.t_to_sigma(timesteps)
    return k_diffusion_sampling.append_zero(sigmas)
def get_sigmas_simple(model, n, s_min, s_max, device):
    """Build an ``n``-step schedule by striding through ``model.sigmas``.

    The schedule starts exactly at ``s_max``. It then takes ``n - 2``
    intermediate values from the slice of ``model.sigmas`` running from the
    entry nearest ``s_min`` up to (excluding) the entry nearest ``s_max``,
    walking backwards from the end of that slice with stride ``len / n``.
    It finishes with ``s_min`` and a terminal ``0.0``.

    Args:
        model: object exposing a 1-D ``sigmas`` tensor (presumably sorted
            ascending -- TODO confirm; otherwise the slice may be empty).
        n: number of sampling steps; the result has ``n + 1`` entries.
        s_min: lowest non-zero sigma of the schedule.
        s_max: first (highest) sigma of the schedule.
        device: device on which the returned tensor is created.

    Returns:
        1-D tensor of length ``n + 1`` whose last element is ``0.0``.

    Raises:
        ZeroDivisionError: if ``n`` is 0.
        IndexError: if ``n >= 3`` and no table entries lie between the
            entries matched for ``s_min`` and ``s_max``.
    """
    if n == 1:
        # One step goes straight from s_max to 0. The general path below
        # would also insert s_min and silently produce a second step.
        return torch.tensor([float(s_max), 0.0], device=device)

    min_idx = torch.argmin(torch.abs(model.sigmas - s_min))
    max_idx = torch.argmin(torch.abs(model.sigmas - s_max))
    # End index is exclusive: s_max itself is prepended explicitly below.
    sigmas_slice = model.sigmas[min_idx:max_idx]
    stride = len(sigmas_slice) / n

    sigs = [float(s_max)]
    # Negative indices walk backwards from the end of the slice.
    sigs.extend(float(sigmas_slice[-(1 + int(x * stride))]) for x in range(1, n - 1))
    sigs += [float(s_min), 0.0]
    return torch.tensor(sigs, device=device)
def get_sigmas_ddim_uniform(model, n, s_min, s_max, device):
    """Build a DDIM-style schedule from ``n`` uniformly spaced integer timesteps.

    The sigma bounds are mapped to timesteps with ``model.sigma_to_t``. ``n``
    evenly spaced timesteps are taken from the high bound down to the low one,
    stored as ``int16`` (so fractional parts are dropped) and capped at 999.
    Each timestep is mapped back with ``model.t_to_sigma``, and ``0.0`` is
    appended.

    Args:
        model: object providing ``sigma_to_t`` and ``t_to_sigma``; assumed to
            return a single-element result per timestep -- TODO confirm.
        n: number of timesteps; the result has ``n + 1`` entries.
        s_min: lower sigma bound.
        s_max: upper sigma bound.
        device: device on which tensors are created.

    Returns:
        1-D tensor of length ``n + 1`` whose last element is ``0.0``.
    """
    t_min, t_max = model.sigma_to_t(torch.tensor([s_min, s_max], device=device))
    ddim_timesteps = torch.linspace(t_max, t_min, n, dtype=torch.int16, device=device)
    # Cap on the tensor itself so that t_to_sigma always receives a tensor
    # element of the same dtype, never a bare Python int.
    ddim_timesteps = ddim_timesteps.clamp(max=999)
    # float() gives torch.tensor a uniform list of Python scalars.
    sigs = [float(model.t_to_sigma(ts)) for ts in ddim_timesteps]
    sigs.append(0.0)
    return torch.tensor(sigs, device=device)
def get_sigmas_simple_test(model, n, s_min, s_max, device):
    """Variant of the 'simple' schedule that draws every point from ``model.sigmas``.

    The window runs from the table entry nearest ``s_min`` up to (excluding)
    the entry nearest ``s_max``. ``n`` entries are picked from it, walking
    backwards from the window's end with stride ``len(window) / n``, and
    ``0.0`` is appended. The result has ``n + 1`` entries.
    ``n == 0`` raises ZeroDivisionError, and an empty window with ``n >= 1``
    raises IndexError.
    """
    lo = torch.argmin(torch.abs(model.sigmas - s_min))
    hi = torch.argmin(torch.abs(model.sigmas - s_max))
    window = model.sigmas[lo:hi]
    step = len(window) / n
    picked = [float(window[-(1 + int(i * step))]) for i in range(n)]
    return torch.tensor(picked + [0.0], device=device)
# Scheduler name -> schedule builder. Every builder listed here takes
# (model, n, s_min, s_max, device) and returns the sigma schedule.
SCHEDULER_MAPPING = dict(
    normal=get_sigmas_normal,
    karras=get_sigmas_karras,
    exponential=get_sigmas_exponential,
    simple=get_sigmas_simple,
    ddim_uniform=get_sigmas_ddim_uniform,
    simple_test=get_sigmas_simple_test,
)
|