File size: 2,167 Bytes
6fecfbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from comfy.k_diffusion import sampling as k_diffusion_sampling
import torch

def get_sigmas_karras(model, n, s_min, s_max, device):
    """Delegate to k_diffusion's Karras schedule for ``n`` steps between ``s_min`` and ``s_max``.

    ``model`` is not used; it is accepted only so this builder shares the
    ``(model, n, s_min, s_max, device)`` signature of the other entries in
    ``SCHEDULER_MAPPING``.
    """
    karras = k_diffusion_sampling.get_sigmas_karras
    sigmas = karras(n, s_min, s_max, device=device)
    return sigmas


def get_sigmas_exponential(model, n, s_min, s_max, device):
    """Delegate to k_diffusion's exponential schedule for ``n`` steps between ``s_min`` and ``s_max``.

    ``model`` is not used; it is accepted only so this builder shares the
    ``(model, n, s_min, s_max, device)`` signature of the other entries in
    ``SCHEDULER_MAPPING``.
    """
    exponential = k_diffusion_sampling.get_sigmas_exponential
    sigmas = exponential(n, s_min, s_max, device=device)
    return sigmas


def get_sigmas_normal(model, n, s_min, s_max, device):
    """Space ``n`` sigmas linearly in the model's timestep domain, then append a zero.

    ``s_min`` and ``s_max`` are mapped to timesteps with ``model.sigma_to_t``.
    ``n`` timesteps are laid out linearly from the ``s_max`` end to the
    ``s_min`` end, mapped back through ``model.t_to_sigma``, and finished with
    ``k_diffusion_sampling.append_zero``.
    """
    endpoints = torch.tensor([s_min, s_max], device=device)
    t_low, t_high = model.sigma_to_t(endpoints)
    timesteps = torch.linspace(t_high, t_low, n, device=device)
    sigmas = model.t_to_sigma(timesteps)
    return k_diffusion_sampling.append_zero(sigmas)


def get_sigmas_simple(model, n, s_min, s_max, device):
    """Build an ``n``-step schedule by striding through the model's own sigma table.

    The schedule starts at ``s_max``. It then takes ``n - 2`` evenly strided
    entries of ``model.sigmas``, restricted to the slice between the entries
    nearest ``s_min`` and ``s_max``, walking from the high end of that slice
    downwards. It ends with ``s_min`` and a final ``0.0``.

    Args:
        model: object exposing a 1-D ``sigmas`` tensor. The slice below is
            only non-empty if that table is sorted ascending -- TODO confirm
            against callers.
        n: number of sampling steps.
        s_min: smallest non-zero sigma of the schedule.
        s_max: largest sigma of the schedule.
        device: device for the returned tensor.

    Returns:
        1-D float tensor of ``n + 1`` sigmas, or ``[0.0]`` when ``n <= 0``.
    """
    # Degenerate step counts. The general path below divides by ``n``
    # (ZeroDivisionError for n == 0), and for n == 1 it would emit both
    # endpoints, returning n + 2 values instead of n + 1.
    if n <= 0:
        return torch.tensor([0.0], device=device)
    if n == 1:
        return torch.tensor([float(s_max), 0.0], device=device)

    min_idx = torch.argmin(torch.abs(model.sigmas - s_min))
    max_idx = torch.argmin(torch.abs(model.sigmas - s_max))
    sigmas_slice = model.sigmas[min_idx:max_idx]
    ss = len(sigmas_slice) / n
    # The endpoints are supplied explicitly, so only the n - 2 interior steps
    # are read from the table, indexing backwards from its high end.
    sigs = [float(s_max)]
    sigs += [float(sigmas_slice[-(1 + int(x * ss))]) for x in range(1, n - 1)]
    sigs += [float(s_min), 0.0]
    return torch.tensor(sigs, device=device)


def get_sigmas_ddim_uniform(model, n, s_min, s_max, device, max_timestep=999):
    """Build a schedule from ``n`` uniformly spaced integer timesteps.

    The steps are:

    1. ``s_min`` and ``s_max`` are mapped to timesteps with
       ``model.sigma_to_t``.
    2. ``n`` timesteps are laid out linearly from the ``s_max`` end to the
       ``s_min`` end, in ``int16`` so fractional timesteps are truncated.
    3. The timesteps are clamped to ``max_timestep``.
    4. Each timestep is converted back with ``model.t_to_sigma``.
    5. A trailing ``0.0`` is appended.

    Args:
        model: object exposing ``sigma_to_t`` and ``t_to_sigma``.
        n: number of sampling steps.
        s_min: smallest non-zero sigma of the schedule.
        s_max: largest sigma of the schedule.
        device: device for the intermediate and returned tensors.
        max_timestep: upper bound applied to every timestep. The default of
            999 keeps the previous hard-coded limit; presumably it is the last
            index of a 1000-entry timestep table -- confirm for other models.

    Returns:
        1-D float tensor of ``n + 1`` sigmas.
    """
    t_min, t_max = model.sigma_to_t(torch.tensor([s_min, s_max], device=device))
    # Plain Python floats keep linspace's endpoints scalar regardless of what
    # sigma_to_t returns.
    ddim_timesteps = torch.linspace(
        float(t_max), float(t_min), n, dtype=torch.int16, device=device
    )
    # Clamp as a tensor op. Substituting a Python int (the previous approach)
    # hands t_to_sigma a non-tensor for that one step, which breaks any
    # implementation that calls tensor methods on its argument.
    ddim_timesteps = ddim_timesteps.clamp(max=max_timestep)
    # float() gives plain scalars, matching how the other builders assemble
    # their sigma lists before torch.tensor.
    sigs = [float(model.t_to_sigma(ts)) for ts in ddim_timesteps]
    sigs.append(0.0)
    return torch.tensor(sigs, device=device)


def get_sigmas_simple_test(model, n, s_min, s_max, device):
    """Pick ``n`` evenly strided entries from the model's sigma table, then append a zero.

    The table is first narrowed to the slice between the entries nearest
    ``s_min`` and ``s_max``. Entries are taken from the high end of that slice
    downwards, starting with its last element, with no explicit endpoints.
    The result has ``n + 1`` values.
    """
    table = model.sigmas
    lo = torch.argmin(torch.abs(table - s_min))
    hi = torch.argmin(torch.abs(table - s_max))
    window = table[lo:hi]
    stride = len(window) / n
    picked = [float(window[-(1 + int(step * stride))]) for step in range(n)]
    return torch.tensor(picked + [0.0], device=device)


# Scheduler name -> sigma-schedule builder. Every builder is called as
# builder(model, n, s_min, s_max, device).
SCHEDULER_MAPPING = dict(
    normal=get_sigmas_normal,
    karras=get_sigmas_karras,
    exponential=get_sigmas_exponential,
    simple=get_sigmas_simple,
    ddim_uniform=get_sigmas_ddim_uniform,
    simple_test=get_sigmas_simple_test,
)