"""
Based on: https://github.com/crowsonkb/k-diffusion

Copyright (c) 2022 Katherine Crowson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion, mean_flat


class KarrasDenoiser:
    """Preconditioned denoiser in the style of Karras et al. (2022),
    "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM).
    """

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def get_snr(self, sigmas):
        return sigmas**-2

    def get_sigmas(self, sigmas):
        return sigmas

    def get_scalings(self, sigma):
        # EDM preconditioning: the model sees a unit-variance input c_in * x_t,
        # and the denoised prediction is c_skip * x_t + c_out * model_output.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in

    def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)

        terms = {}

        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
        model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
        # The regression target is chosen so that c_skip * x_t + c_out * target == x_start.
        target = (x_start - c_skip * x_t) / c_out

        terms["mse"] = mean_flat((model_output - target) ** 2)
        terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)

        # "vb" is never set by this class; the branch is kept for parity with
        # GaussianDiffusion-style loss dicts.
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]

        return terms

    def denoise(self, model, x_t, sigmas, **model_kwargs):
        c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
        # Map sigma to the timestep conditioning the model expects: EDM's
        # c_noise = 0.25 * ln(sigma), scaled by 1000 to the range of discrete
        # timestep embeddings; the epsilon guards log(0).
        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
        model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised
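

# A minimal, hedged sketch (not part of the original module) of how
# KarrasDenoiser is used for training: draw per-sample noise levels, then read
# the loss dict. The log-normal sigma draw (mean -1.2, std 1.2) follows
# Karras et al. (2022) but is an assumption here; any positive sigma
# distribution works.
def _example_training_step(model, x_start):
    denoiser = KarrasDenoiser(sigma_data=0.5)
    # One noise level per batch element, drawn log-normally.
    sigmas = th.exp(th.randn(x_start.shape[0]) * 1.2 - 1.2)
    terms = denoiser.training_losses(model, x_start, sigmas)
    return terms["loss"].mean()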


class GaussianToKarrasDenoiser:
    """Adapts a discrete-time GaussianDiffusion model to the sigma-parameterized
    interface expected by the Karras samplers below."""

    def __init__(self, model, diffusion):
        from scipy import interpolate

        self.model = model
        self.diffusion = diffusion
        self.alpha_cumprod_to_t = interpolate.interp1d(
            diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)
        )

    def sigma_to_t(self, sigma):
        alpha_cumprod = 1.0 / (sigma**2 + 1)
        # Clamp sigmas outside the interpolation table to the first/last
        # discrete timestep.
        if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
            return 0
        elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
            return self.diffusion.num_timesteps - 1
        else:
            return float(self.alpha_cumprod_to_t(alpha_cumprod))

    def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):
        t = th.tensor(
            [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],
            dtype=th.long,
            device=sigmas.device,
        )
        c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)
        out = self.diffusion.p_mean_variance(
            self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        return None, out["pred_xstart"]
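

# Hedged numeric sketch (not original code) of the relation used in sigma_to_t
# above: under the variance-preserving convention,
# alpha_cumprod = 1 / (sigma**2 + 1), equivalently sigma = sqrt(1 / alpha_cumprod - 1).
def _example_sigma_alpha_relation():
    sigma = th.tensor(2.0)
    alpha_cumprod = 1.0 / (sigma**2 + 1)
    assert th.allclose(th.sqrt(1.0 / alpha_cumprod - 1.0), sigma)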


def karras_sample(*args, **kwargs):
    last = None
    for x in karras_sample_progressive(*args, **kwargs):
        last = x["x"]
    return last
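

# Hedged usage sketch (not part of the original module): sampling from a
# trained model with the Karras Heun sampler. The shape and step count are
# illustrative only.
def _example_sampling(model, device="cpu"):
    diffusion = KarrasDenoiser(sigma_data=0.5)
    return karras_sample(
        diffusion,
        model,
        shape=(4, 3, 64, 64),
        steps=40,
        device=device,
        sampler="heun",
    )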


def karras_sample_progressive(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=False,
    model_kwargs=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    guidance_scale=0.0,
):
    if model_kwargs is None:
        model_kwargs = {}
    sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
    x_T = th.randn(*shape, device=device) * sigma_max
    sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
        sampler
    ]

    if sampler != "ancestral":
        sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
    else:
        sampler_args = {}

    if isinstance(diffusion, KarrasDenoiser):

        def denoiser(x_t, sigma):
            _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
            if clip_denoised:
                denoised = denoised.clamp(-1, 1)
            return denoised

    elif isinstance(diffusion, GaussianDiffusion):
        model = GaussianToKarrasDenoiser(model, diffusion)

        def denoiser(x_t, sigma):
            _, denoised = model.denoise(
                x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs
            )
            return denoised

    else:
        raise NotImplementedError

    if guidance_scale != 0 and guidance_scale != 1:
        # Classifier-free guidance: run the conditional and unconditional
        # branches in one doubled batch, then combine the two x_0 predictions.
        def guided_denoiser(x_t, sigma):
            x_t = th.cat([x_t, x_t], dim=0)
            sigma = th.cat([sigma, sigma], dim=0)
            x_0 = denoiser(x_t, sigma)
            cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
            x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
            return x_0

    else:
        guided_denoiser = denoiser

    for obj in sample_fn(
        guided_denoiser,
        x_T,
        sigmas,
        progress=progress,
        **sampler_args,
    ):
        if isinstance(diffusion, GaussianDiffusion):
            yield diffusion.unscale_out_dict(obj)
        else:
            yield obj
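

# Hedged numeric illustration (not original code) of the guidance combine rule
# above: guidance_scale == 0 returns the unconditional prediction,
# guidance_scale == 1 the conditional one, and larger values extrapolate
# beyond the conditional prediction.
def _example_guidance_combine(guidance_scale=3.0):
    cond_x_0 = th.tensor([1.0])
    uncond_x_0 = th.tensor([0.2])
    return uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)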


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = th.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
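

# Hedged sketch (not original code): the schedule interpolates between
# sigma_max and sigma_min in sigma**(1/rho) space and appends a trailing zero,
# so n steps produce n + 1 noise levels.
def _example_schedule():
    sigmas = get_sigmas_karras(10, 0.002, 80.0)
    assert sigmas.shape == (11,)
    assert th.isclose(sigmas[0], th.tensor(80.0))
    assert sigmas[-1] == 0
    return sigmas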


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative, d = (x - D(x)) / sigma."""
    return (x - denoised) / append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
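

# Hedged numeric check (not original code): sigma_down and sigma_up recombine
# to the target level, sigma_down**2 + sigma_up**2 == sigma_to**2, so the
# ancestral step preserves the marginal noise level.
def _example_ancestral_identity():
    sigma_down, sigma_up = get_ancestral_step(th.tensor(2.0), th.tensor(1.0))
    assert th.allclose(sigma_down**2 + sigma_up**2, th.tensor(1.0))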


@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, progress=False):
    """Ancestral sampling with Euler method steps."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        denoised = model(x, sigmas[i] * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised}
        d = to_d(x, sigmas[i], denoised)
        # Euler step down to sigma_down, then re-inject sigma_up worth of noise.
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + th.randn_like(x) * sigma_up
    yield {"x": x, "pred_xstart": x}


@th.no_grad()
def sample_heun(
    denoiser,
    x,
    sigmas,
    progress=False,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        )
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised}
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    yield {"x": x, "pred_xstart": denoised}


@th.no_grad()
def sample_dpm(
    denoiser,
    x,
    sigmas,
    progress=False,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        )
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}

        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = denoiser(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    yield {"x": x, "pred_xstart": denoised}


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]
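

# Hedged sketch (not original code): append_dims reshapes a per-sample vector
# for broadcasting, e.g. (B,) -> (B, 1, 1, 1) to scale a (B, C, H, W) batch.
def _example_append_dims():
    sigmas = th.ones(4)
    assert append_dims(sigmas, 4).shape == (4, 1, 1, 1)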


def append_zero(x):
    return th.cat([x, x.new_zeros([1])])
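

# A hedged, self-contained smoke test (not in the original module): drive
# sample_heun with a dummy denoiser that always predicts x_0 = 0, checking
# only shapes and generator plumbing, not sample quality. Because of the
# relative import above, run it as a module, e.g. `python -m <package>.k_diffusion`.
if __name__ == "__main__":
    test_sigmas = get_sigmas_karras(8, 0.002, 80.0)
    x_T = th.randn(2, 3, 8, 8) * test_sigmas[0]

    def dummy_denoiser(x_t, sigma):
        # Stand-in model: always predicts x_0 = 0.
        return th.zeros_like(x_t)

    *_, final = sample_heun(dummy_denoiser, x_T, test_sigmas)
    assert final["x"].shape == x_T.shape
    print("smoke test passed:", tuple(final["x"].shape))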