Spaces:
Build error
Build error
"""Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py.""" | |
from __future__ import annotations | |
from typing import Union | |
import rerun as rr | |
import torch | |
from omegaconf import ListConfig, OmegaConf | |
from tqdm import tqdm | |
from ...util import append_dims, default, instantiate_from_config | |
from .sampling_utils import to_d | |
class BaseDiffusionSampler:
    """Common setup and helpers shared by the samplers in this module.

    The sampler owns two collaborators, both built with
    ``instantiate_from_config``:

    * ``discretization`` -- called as ``discretization(num_steps, device=...)``
      to obtain the sequence of sigmas to iterate over.
    * ``guider`` -- used to prepare the denoiser inputs and to post-process
      the denoiser output (see ``denoise``).
    """

    def __init__(
        self,
        discretization_config: Union[dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
        device: str = "cuda",
    ):
        """Store the sampling settings and instantiate the collaborators.

        Args:
            discretization_config: Config passed to ``instantiate_from_config``
                to build the discretization.
            num_steps: Default step count handed to the discretization; can be
                overridden per call in ``prepare_sampling_loop``.
            guider_config: Config passed to ``instantiate_from_config`` to
                build the guider.
            verbose: When true, ``get_sigma_gen`` prints the sampling setup
                and wraps the step indices in a progress bar.
            device: Device argument forwarded to the discretization.
        """
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        # NOTE(review): the default ``guider_config=None`` is forwarded as-is;
        # confirm that ``instantiate_from_config`` accepts ``None``.
        self.guider = instantiate_from_config(guider_config)
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Build the sigma schedule and scale the initial sample.

        Args:
            x: Initial sample tensor. It is NOT modified; a scaled copy is
                returned instead.
            cond: Conditioning, returned unchanged.
            uc: Unconditional conditioning; resolved through
                ``default(uc, cond)`` when omitted.
            num_steps: Per-call override of ``self.num_steps``.

        Returns:
            Tuple ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``x`` is the
            input multiplied by ``sqrt(1 + sigmas[0] ** 2)``, ``s_in`` is a
            ones vector of length ``x.shape[0]`` and ``num_sigmas`` is
            ``len(sigmas)``.
        """
        steps = self.num_steps if num_steps is None else num_steps
        sigmas = self.discretization(steps, device=self.device)
        uc = default(uc, cond)

        # Scale out of place: an in-place ``x *= ...`` would also rescale the
        # tensor object owned by the caller, so reusing the same noise tensor
        # for a second sampling call would compound the scaling.
        x = x * torch.sqrt(1.0 + sigmas[0] ** 2)

        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, cond_mask, uc):
        """Run the denoiser on guider-prepared inputs and apply the guider.

        Args:
            x: Current sample.
            denoiser: Callable invoked with the unpacked result of
                ``self.guider.prepare_inputs(x, sigma, cond, cond_mask, uc)``.
            sigma: Noise level for this call.
            cond: Conditioning forwarded to the guider.
            cond_mask: Conditioning mask forwarded to the guider.
            uc: Unconditional conditioning forwarded to the guider.

        Returns:
            The result of ``self.guider(denoiser_output, sigma)``.
        """
        denoiser_inputs = self.guider.prepare_inputs(x, sigma, cond, cond_mask, uc)
        denoised = denoiser(*denoiser_inputs)
        return self.guider(denoised, sigma)

    def get_sigma_gen(self, num_sigmas):
        """Return an iterable over the step indices ``0 .. num_sigmas - 2``.

        One index is produced per transition between consecutive sigmas, i.e.
        ``num_sigmas - 1`` indices in total (none when ``num_sigmas <= 1``).
        In verbose mode the sampling setup is printed and the indices are
        wrapped in a ``tqdm`` progress bar.

        Args:
            num_sigmas: Number of sigmas in the schedule.

        Returns:
            A ``range``, or a ``tqdm`` wrapper around it when verbose.
        """
        sigma_generator = range(max(num_sigmas - 1, 0))
        if not self.verbose:
            return sigma_generator

        print("#" * 30, " Sampling Setting ", "#" * 30)
        print(f"Sampler: {self.__class__.__name__}")
        print(f"Discretization: {self.discretization.__class__.__name__}")
        print(f"Guider: {self.guider.__class__.__name__}")

        # The bar total and the reported count must match the number of
        # iterations actually performed, otherwise the bar never completes.
        step_count = len(sigma_generator)
        return tqdm(
            sigma_generator,
            total=step_count,
            desc=f"Sampling with {self.__class__.__name__} for {step_count} steps",
        )
class SingleStepDiffusionSampler(BaseDiffusionSampler):
    """Sampler base that declares a ``sampler_step`` hook and an Euler update."""

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        """Per-step update hook; this base version always raises.

        Raises:
            NotImplementedError: Always. Subclasses provide the implementation.
        """
        raise NotImplementedError

    def euler_step(self, x, d, dt):
        """Return ``x`` advanced by the increment ``dt * d``.

        Args:
            x: Current value.
            d: Derivative (direction) term.
            dt: Step size.

        Returns:
            ``x + dt * d``.
        """
        increment = dt * d
        return x + increment
class EulerEDMSampler(SingleStepDiffusionSampler):
    """Euler sampler with optional noise injection ("churn") before each step.

    On every step inside the ``[s_tmin, s_tmax]`` sigma window, the sample is
    first pushed to a higher noise level ``sigma_hat = sigma * (1 + gamma)`` by
    adding fresh Gaussian noise, then a single Euler update is taken from
    ``sigma_hat`` to the next sigma.
    """

    def __init__(
        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
    ):
        """Store the churn settings; remaining arguments go to the base class.

        Args:
            s_churn: Total churn budget; the per-step ``gamma`` is
                ``min(s_churn / (num_sigmas - 1), sqrt(2) - 1)``.
            s_tmin: Lower bound of the sigma window in which churn is applied.
            s_tmax: Upper bound of the sigma window in which churn is applied.
            s_noise: Multiplier applied to the injected Gaussian noise.
            *args: Forwarded to the base-class initializer.
            **kwargs: Forwarded to the base-class initializer.
        """
        super().__init__(*args, **kwargs)
        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def sampler_step(
        self, sigma, next_sigma, denoiser, x, cond, cond_mask=None, uc=None, gamma=0.0
    ):
        """Take one Euler step from ``sigma`` towards ``next_sigma``.

        Args:
            sigma: Current noise level.
            next_sigma: Noise level to step to.
            denoiser: Denoiser callable, forwarded to ``self.denoise``.
            x: Current sample.
            cond: Conditioning, forwarded to ``self.denoise``.
            cond_mask: Conditioning mask, forwarded to ``self.denoise``.
            uc: Unconditional conditioning, forwarded to ``self.denoise``.
            gamma: Relative noise increase applied before the step; ``0``
                disables noise injection.

        Returns:
            The sample after the Euler update.
        """
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            # Add exactly the noise needed to raise the level from sigma to
            # sigma_hat (variances add, hence the difference of squares).
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = self.denoise(x, denoiser, sigma_hat, cond, cond_mask, uc)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        return self.euler_step(x, d, dt)

    @staticmethod
    def _overwrite_cond_frames(x, cond_frame, cond_mask):
        """Blend ``cond_frame`` into ``x`` where ``cond_mask`` is set."""
        keep = append_dims(1 - cond_mask, x.ndim)
        take = append_dims(cond_mask, cond_frame.ndim)
        return x * keep + cond_frame * take

    def __call__(
        self,
        denoiser,
        x,  # x is randn
        cond,
        uc=None,
        cond_frame=None,
        cond_mask=None,
        num_steps=None,
        num_sequence=0,
        log_queue=None,
    ):
        """Run the full sampling loop and return the final sample.

        Args:
            denoiser: Denoiser callable used on every step.
            x: Initial noise sample.
            cond: Conditioning.
            uc: Unconditional conditioning (resolved by
                ``prepare_sampling_loop`` when omitted).
            cond_frame: Values written into ``x`` where ``cond_mask`` is set.
                Required whenever ``cond_mask`` has any non-zero entry.
            cond_mask: Mask selecting the entries of ``x`` that are replaced
                by ``cond_frame`` before every step and once more at the end.
            num_steps: Per-call override of the configured step count.
            num_sequence: Index used in the logged entry name and in the
                "combined" log coordinate.
            log_queue: Optional object with a ``put`` method that receives one
                ``(name, rr.Tensor, [(key, value), ...])`` tuple per step.
                When ``None`` (the default) nothing is logged.

        Returns:
            The sample after the last step.
        """
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
        replace_cond_frames = cond_mask is not None and cond_mask.any()

        for i in tqdm(self.get_sigma_gen(num_sigmas), "Diffusion steps"):
            if replace_cond_frames:
                x = self._overwrite_cond_frames(x, cond_frame, cond_mask)

            # Churn only inside the configured sigma window.
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                cond_mask,
                uc,
                gamma,
            )

            # ``log_queue`` defaults to None, so logging must be optional;
            # calling ``.put`` unconditionally crashed every default call.
            if log_queue is not None:
                log_queue.put(
                    (
                        f"diffusion_{num_sequence}",
                        rr.Tensor(x.numpy(force=True)),
                        [
                            ("frame_id", 0),
                            ("diffusion", i),
                            (
                                "combined",
                                2 * num_sequence + (i * 1.0 / num_sigmas),
                            ),
                        ],
                    )
                )

        if replace_cond_frames:
            x = self._overwrite_cond_frames(x, cond_frame, cond_mask)
        return x