Spaces:
Build error
Build error
File size: 5,083 Bytes
d323598 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
"""Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py."""
from __future__ import annotations
from typing import Union
import rerun as rr
import torch
from omegaconf import ListConfig, OmegaConf
from tqdm import tqdm
from ...util import append_dims, default, instantiate_from_config
from .sampling_utils import to_d
class BaseDiffusionSampler:
    """Common setup and helpers shared by the diffusion samplers in this module.

    The sampler owns a *discretization* (a callable that returns the sigma
    schedule) and a *guider* (an object that prepares the denoiser inputs and
    post-processes the denoiser output). Both are built from configs with
    ``instantiate_from_config``.
    """

    def __init__(
        self,
        discretization_config: Union[dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
        device: str = "cuda",
    ):
        """Build the discretization and guider and store sampling options.

        Args:
            discretization_config: Config passed to ``instantiate_from_config``
                to build the sigma-schedule callable.
            num_steps: Default number of steps handed to the discretization;
                may be overridden per call in ``prepare_sampling_loop``.
            guider_config: Config passed to ``instantiate_from_config`` to
                build the guider.
            verbose: When true, ``get_sigma_gen`` prints the sampling setup and
                wraps the step indices in a progress bar.
            device: Device forwarded to the discretization.
        """
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        # NOTE(review): guider_config defaults to None and is forwarded as-is;
        # whether instantiate_from_config accepts None is not visible here.
        self.guider = instantiate_from_config(guider_config)
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Compute the sigma schedule and scale the initial sample.

        Args:
            x: Initial sample tensor; its first dimension sizes ``s_in``.
            cond: Conditioning, returned unchanged.
            uc: Unconditional conditioning; resolved against ``cond`` via
                ``default`` when omitted.
            num_steps: Per-call override of ``self.num_steps``.

        Returns:
            Tuple ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``x`` is
            the input scaled by ``sqrt(1 + sigmas[0] ** 2)`` and ``s_in`` is a
            vector of ones with one entry per element of ``x``'s first
            dimension.
        """
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        uc = default(uc, cond)

        # Scale out-of-place: an in-place ``x *= ...`` would silently rescale
        # the caller's tensor, which compounds if the same noise is reused.
        x = x * torch.sqrt(1.0 + sigmas[0] ** 2)
        num_sigmas = len(sigmas)

        s_in = x.new_ones([x.shape[0]])

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, cond_mask, uc):
        """Run the denoiser on guider-prepared inputs and apply the guider.

        The guider's ``prepare_inputs`` result is unpacked as the positional
        arguments of ``denoiser``; the guider is then called on the denoiser
        output together with ``sigma``.
        """
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, cond_mask, uc))
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        """Return an iterable over the step indices ``0 .. num_sigmas - 2``.

        In verbose mode the sampler setup is printed and the indices are
        wrapped in a ``tqdm`` progress bar.
        """
        # One step per pair of consecutive sigmas.
        step_count = num_sigmas - 1
        sigma_generator = range(step_count)
        if self.verbose:
            print("#" * 30, " Sampling Setting ", "#" * 30)
            print(f"Sampler: {self.__class__.__name__}")
            print(f"Discretization: {self.discretization.__class__.__name__}")
            print(f"Guider: {self.guider.__class__.__name__}")
            sigma_generator = tqdm(
                sigma_generator,
                # Must match the length of the range, otherwise the bar never
                # reaches 100%.
                total=step_count,
                desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
            )
        return sigma_generator
class SingleStepDiffusionSampler(BaseDiffusionSampler):
    """Base for samplers built around a per-step ``sampler_step`` hook.

    Subclasses implement ``sampler_step``; ``euler_step`` provides the plain
    explicit Euler update they can reuse.
    """

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        """Advance ``x`` by one step; must be overridden by subclasses."""
        raise NotImplementedError()

    def euler_step(self, x, d, dt):
        """Return ``x`` moved along direction ``d`` by step size ``dt``."""
        increment = dt * d
        return x + increment
class EulerEDMSampler(SingleStepDiffusionSampler):
    """Euler sampler with optional noise injection ("churn") per step.

    Besides sampling, ``__call__`` can re-impose conditioning frames on the
    sample around every step and push each intermediate sample to a logging
    queue.
    """

    def __init__(
        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
    ):
        """Store the churn parameters; remaining arguments go to the base class.

        Args:
            s_churn: Total churn budget; divided by the number of steps to get
                the per-step ``gamma``.
            s_tmin: Lower sigma bound (inclusive) for applying churn.
            s_tmax: Upper sigma bound (inclusive) for applying churn.
            s_noise: Scale of the random noise injected when ``gamma > 0``.
        """
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def sampler_step(
        self, sigma, next_sigma, denoiser, x, cond, cond_mask=None, uc=None, gamma=0.0
    ):
        """Take one Euler step from ``sigma`` towards ``next_sigma``.

        When ``gamma > 0`` the noise level is first raised to
        ``sigma_hat = sigma * (gamma + 1)`` by adding random noise scaled by
        ``s_noise * sqrt(sigma_hat ** 2 - sigma ** 2)``. The sample is then
        denoised at ``sigma_hat`` and advanced by ``next_sigma - sigma_hat``.

        Returns:
            The updated sample.
        """
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = self.denoise(x, denoiser, sigma_hat, cond, cond_mask, uc)
        # presumably the ODE derivative at (x, sigma_hat); to_d is defined in
        # sampling_utils and not visible here.
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        return self.euler_step(x, d, dt)

    @staticmethod
    def _blend_cond_frames(x, cond_frame, cond_mask):
        """Keep ``x`` where the mask is 0 and take ``cond_frame`` where it is 1."""
        return x * append_dims(1 - cond_mask, x.ndim) + cond_frame * append_dims(
            cond_mask, cond_frame.ndim
        )

    def __call__(
        self,
        denoiser,
        x,  # x is randn
        cond,
        uc=None,
        cond_frame=None,
        cond_mask=None,
        num_steps=None,
        num_sequence=0,
        log_queue=None,
    ):
        """Run the full sampling loop and return the final sample.

        Args:
            denoiser: Callable invoked through ``self.denoise``.
            x: Initial noise sample.
            cond: Conditioning passed to the guider.
            uc: Unconditional conditioning (resolved in
                ``prepare_sampling_loop`` when omitted).
            cond_frame: Frames blended into ``x`` wherever ``cond_mask`` is
                set; required when ``cond_mask`` has any non-zero entry.
            cond_mask: Mask selecting which entries are replaced by
                ``cond_frame``.
            num_steps: Per-call override of the configured step count.
            num_sequence: Sequence index used in the log entry name and the
                "combined" timeline value.
            log_queue: Optional object with a ``put`` method receiving one
                ``(name, rr.Tensor, timeline entries)`` tuple per step. When
                ``None`` nothing is logged.

        Returns:
            The sample after the last step, with conditioning frames
            re-imposed when a mask is active.
        """
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
        replace_cond_frames = cond_mask is not None and cond_mask.any()

        # NOTE(review): in verbose mode get_sigma_gen already returns a tqdm
        # iterator, so this wraps a second progress bar around it.
        for i in tqdm(self.get_sigma_gen(num_sigmas), "Diffusion steps"):
            if replace_cond_frames:
                x = self._blend_cond_frames(x, cond_frame, cond_mask)

            # Churn only inside [s_tmin, s_tmax]; per-step amount is capped at
            # sqrt(2) - 1.
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                cond_mask,
                uc,
                gamma,
            )

            # log_queue defaults to None; skip logging (and the tensor copy to
            # NumPy) instead of failing on ``None.put``.
            if log_queue is not None:
                log_queue.put(
                    (
                        f"diffusion_{num_sequence}",
                        rr.Tensor(x.numpy(force=True)),
                        [
                            ("frame_id", 0),
                            ("diffusion", i),
                            (
                                "combined",
                                2 * num_sequence + (i * 1.0 / num_sigmas),
                            ),
                        ],
                    )
                )

        if replace_cond_frames:
            x = self._blend_cond_frames(x, cond_frame, cond_mask)

        return x
|