"""Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py."""

from __future__ import annotations

from typing import Union

import rerun as rr
import torch
from omegaconf import ListConfig, OmegaConf
from tqdm import tqdm

from ...util import append_dims, default, instantiate_from_config
from .sampling_utils import to_d


class BaseDiffusionSampler:
    """Shared setup for diffusion samplers: sigma schedule, guidance, progress.

    Args:
        discretization_config: Config given to ``instantiate_from_config``; the
            resulting object is called as ``discretization(num_steps, device=...)``
            to produce the sigma schedule.
        num_steps: Default number of sampling steps; may be overridden per call.
        guider_config: Config given to ``instantiate_from_config``; the resulting
            guider is used by :meth:`denoise`.
        verbose: If True, print the sampler setup and show a progress bar from
            :meth:`get_sigma_gen`.
        device: Device on which the sigma schedule is created.
    """

    def __init__(
        self,
        discretization_config: Union[dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
        device: str = "cuda",
    ):
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        self.guider = instantiate_from_config(guider_config)
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Build the sigma schedule and scale the initial noise.

        Args:
            x: Initial noise tensor (callers pass ``randn`` samples). It is NOT
                modified in place; a scaled copy is returned.
            cond: Conditioning forwarded to the guider.
            uc: Unconditional conditioning, resolved via ``default(uc, cond)``
                (presumably falls back to ``cond`` when None — confirm in util).
            num_steps: Overrides ``self.num_steps`` when not None.

        Returns:
            ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``x`` has been
            multiplied by ``sqrt(1 + sigmas[0] ** 2)`` and ``s_in`` is a ones
            vector of length ``x.shape[0]`` used to turn a scalar sigma into a
            per-sample vector.
        """
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        uc = default(uc, cond)

        # Out-of-place on purpose: ``x *= ...`` would silently rescale the
        # caller's tensor (and raises on leaf tensors that require grad).
        x = x * torch.sqrt(1.0 + sigmas[0] ** 2)
        num_sigmas = len(sigmas)

        s_in = x.new_ones([x.shape[0]])

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, cond_mask, uc):
        """Run ``denoiser`` on guider-prepared inputs and post-process with the guider.

        Returns:
            Whatever ``self.guider(denoised, sigma)`` returns.
        """
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, cond_mask, uc))
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        """Return an iterable over step indices ``0 .. num_sigmas - 2``.

        Index ``i`` denotes the transition ``sigmas[i] -> sigmas[i + 1]``, so a
        schedule of ``num_sigmas`` values yields ``num_sigmas - 1`` steps. When
        ``self.verbose`` is set, the sampler setup is printed and the range is
        wrapped in a ``tqdm`` progress bar.
        """
        num_steps = max(num_sigmas - 1, 0)
        sigma_generator = range(num_steps)
        if self.verbose:
            print("#" * 30, " Sampling Setting ", "#" * 30)
            print(f"Sampler: {self.__class__.__name__}")
            print(f"Discretization: {self.discretization.__class__.__name__}")
            print(f"Guider: {self.guider.__class__.__name__}")
            # ``total`` must equal the number of iterations, otherwise the bar
            # never reaches 100% and reports one step too many.
            sigma_generator = tqdm(
                sigma_generator,
                total=num_steps,
                desc=f"Sampling with {self.__class__.__name__} for {num_steps} steps",
            )
        return sigma_generator


class SingleStepDiffusionSampler(BaseDiffusionSampler):
    """Base class for samplers that advance the sample one sigma at a time."""

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        """Advance ``x`` from ``sigma`` to ``next_sigma``; must be overridden."""
        raise NotImplementedError

    def euler_step(self, x, d, dt):
        """Explicit Euler update ``x + dt * d``."""
        return x + dt * d


class EulerEDMSampler(SingleStepDiffusionSampler):
    """Stochastic Euler sampler with "churn" noise injection.

    Follows the stochastic Euler sampler of Karras et al. (2022) as found in
    k-diffusion. Each step first raises the noise level from ``sigma`` to
    ``sigma_hat = sigma * (1 + gamma)`` by adding fresh noise. It then takes a
    single Euler step from ``sigma_hat`` to the next sigma.

    Args:
        s_churn: Total churn; the per-step ``gamma`` is
            ``min(s_churn / (num_sigmas - 1), sqrt(2) - 1)``.
        s_tmin: Lower bound of the sigma range in which churn is applied.
        s_tmax: Upper bound of the sigma range in which churn is applied.
        s_noise: Multiplier on the injected noise.
        *args, **kwargs: Forwarded to :class:`BaseDiffusionSampler`.
    """

    def __init__(
        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def sampler_step(
        self, sigma, next_sigma, denoiser, x, cond, cond_mask=None, uc=None, gamma=0.0
    ):
        """Take one (optionally churned) Euler step from ``sigma`` to ``next_sigma``.

        ``sigma`` and ``next_sigma`` are per-sample vectors (``s_in * sigmas[i]``
        in :meth:`__call__`); ``append_dims`` broadcasts them against ``x``.

        Returns:
            The updated sample at noise level ``next_sigma``.
        """
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            # Raise the noise level from sigma to sigma_hat: the added noise has
            # variance sigma_hat**2 - sigma**2 (times s_noise**2).
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
        denoised = self.denoise(x, denoiser, sigma_hat, cond, cond_mask, uc)
        # Update direction from the denoised estimate (in k-diffusion ``to_d`` is
        # (x - denoised) / sigma — presumably the same here; confirm).
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        return self.euler_step(x, d, dt)

    @staticmethod
    def _replace_cond_frames(x, cond_frame, cond_mask):
        """Blend ``x * (1 - cond_mask) + cond_frame * cond_mask`` with broadcasting."""
        return x * append_dims(1 - cond_mask, x.ndim) + cond_frame * append_dims(
            cond_mask, cond_frame.ndim
        )

    def __call__(
        self,
        denoiser,
        x,  # x is randn
        cond,
        uc=None,
        cond_frame=None,
        cond_mask=None,
        num_steps=None,
        num_sequence=0,
        log_queue=None,
    ):
        """Run the full sampling loop and return the final sample.

        Args:
            denoiser: Callable invoked through :meth:`denoise`.
            x: Initial noise (``randn``); not modified in place.
            cond: Conditioning forwarded to the guider.
            uc: Unconditional conditioning forwarded to the guider.
            cond_frame: Values blended into ``x`` according to ``cond_mask``,
                before every step and once more after the loop.
            cond_mask: Blend mask; replacement is only performed when it is
                given and ``cond_mask.any()`` is true.
            num_steps: Overrides ``self.num_steps`` when not None.
            num_sequence: Index used in the logged entity path and the
                ``combined`` timeline value.
            log_queue: Optional queue. When provided, a tuple
                ``(entity_path, rr.Tensor, timelines)`` is ``put`` on it after
                every step. When None (the default) nothing is logged.

        Returns:
            The denoised sample after the last step.
        """
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
        replace_cond_frames = cond_mask is not None and cond_mask.any()

        step_iter = self.get_sigma_gen(num_sigmas)
        # In verbose mode ``get_sigma_gen`` already returns a progress bar;
        # wrapping it again would render two nested bars.
        if not isinstance(step_iter, tqdm):
            step_iter = tqdm(step_iter, "Diffusion steps")

        for i in step_iter:
            if replace_cond_frames:
                x = self._replace_cond_frames(x, cond_frame, cond_mask)
            # Churn only while the current sigma lies in [s_tmin, s_tmax].
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                cond_mask,
                uc,
                gamma,
            )

            # Logging is optional; the default ``log_queue=None`` previously
            # crashed here with AttributeError.
            if log_queue is not None:
                log_queue.put(
                    (
                        f"diffusion_{num_sequence}",
                        rr.Tensor(x.numpy(force=True)),
                        [
                            ("frame_id", 0),
                            ("diffusion", i),
                            (
                                "combined",
                                2 * num_sequence + (i * 1.0 / num_sigmas),
                            ),
                        ],
                    )
                )

        if replace_cond_frames:
            x = self._replace_cond_frames(x, cond_frame, cond_mask)
        return x