# Wismut's picture
# initial commit
# 0af9841
from math import atan, cos, pi, sin, sqrt
from typing import Any, Callable, List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from torch import Tensor
from .utils import *
"""
Diffusion Training
"""
""" Distributions """
class Distribution:
    """Interface for samplers of per-batch-item noise levels (sigmas)."""

    def __call__(self, num_samples: int, device: torch.device):
        raise NotImplementedError()


class LogNormalDistribution(Distribution):
    """Samples ``exp(mean + std * N(0, 1))``, i.e. a log-normal distribution."""

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        normal = self.mean + self.std * torch.randn((num_samples,), device=device)
        return normal.exp()


class UniformDistribution(Distribution):
    """Samples uniformly from ``[0, 1)``."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        return torch.rand(num_samples, device=device)


class VKDistribution(Distribution):
    """Samples sigmas by inverse-CDF sampling, truncated to ``[min_value, max_value]``.

    The CDF used is ``F(s) = atan(s / sigma_data) * 2 / pi``. A uniform ``u`` is
    drawn in ``[F(min_value), F(max_value))`` and mapped back through
    ``sigma = tan(u * pi / 2) * sigma_data``.

    Args:
        min_value: lower bound of the returned sigmas.
        max_value: upper bound of the returned sigmas (``inf`` for unbounded).
        sigma_data: scale parameter of the distribution.
    """

    def __init__(
        self,
        min_value: float = 0.0,
        max_value: float = float("inf"),
        sigma_data: float = 1.0,
    ):
        self.min_value = min_value
        self.max_value = max_value
        self.sigma_data = sigma_data

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        sigma_data = self.sigma_data
        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
        # Inverse-CDF sampling requires a *uniform* draw in [0, 1). A normal
        # draw would push u outside [min_cdf, max_cdf] and yield negative or
        # out-of-range sigmas.
        u = (max_cdf - min_cdf) * torch.rand((num_samples,), device=device) + min_cdf
        return torch.tan(u * pi / 2) * sigma_data
""" Diffusion Classes """
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Return a view of ``x`` with ``ndim`` size-1 dimensions appended on the right."""
    target_shape = tuple(x.shape) + (1,) * ndim
    return x.view(target_shape)
def clip(x: Tensor, dynamic_threshold: float = 0.0):
    """Clamp ``x`` into ``[-1, 1]``.

    If ``dynamic_threshold`` is 0.0 this is a plain clamp. Otherwise each batch
    item gets its own bound ``s``: the ``dynamic_threshold`` quantile of
    ``|x|`` for that item, floored at 1.0. Values are clamped to ``[-s, s]``
    and then divided by ``s``.
    """
    if dynamic_threshold != 0.0:
        # One row per batch item so the quantile is taken per item
        flattened = rearrange(x, "b ... -> b (...)")
        bound = torch.quantile(flattened.abs(), dynamic_threshold, dim=-1)
        # Never shrink below the static [-1, 1] range
        bound.clamp_(min=1.0)
        # Broadcast the per-item bound over all non-batch dims of x
        bound = pad_dims(bound, ndim=x.ndim - bound.ndim)
        return x.clamp(-bound, bound) / bound
    return x.clamp(-1.0, 1.0)
def to_batch(
    batch_size: int,
    device: torch.device,
    x: Optional[float] = None,
    xs: Optional[Tensor] = None,
) -> Tensor:
    """Return one value per batch item.

    Exactly one of ``x`` (a single value shared by the whole batch) or ``xs``
    (already per-item) must be given. When ``x`` is given, it is expanded into
    a ``(batch_size,)`` tensor on ``device``; otherwise ``xs`` is returned as is.
    """
    assert exists(x) ^ exists(xs), "Either x or xs must be provided"
    batched = xs
    if exists(x):
        batched = torch.full(size=(batch_size,), fill_value=x).to(device)
    assert exists(batched)
    return batched
class Diffusion(nn.Module):
    """Base diffusion class.

    Subclasses set ``alias`` (used to match diffusion types against samplers
    and to select a class by name) and implement ``denoise_fn`` and
    ``forward``.
    """

    # Placed after the docstring so the string above becomes the class __doc__.
    alias: str = ""

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Denoise ``x_noisy`` at per-item ``sigmas`` or a shared ``sigma``."""
        raise NotImplementedError("Diffusion class missing denoise_fn")

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Return the training loss for a clean batch ``x``."""
        raise NotImplementedError("Diffusion class missing forward function")
class VDiffusion(Diffusion):
    """v-objective diffusion.

    Clean data and noise are mixed as ``x * cos(a) + noise * sin(a)`` with
    ``a = sigma * pi / 2``. The network is trained to predict
    ``v = noise * cos(a) - x * sin(a)``.
    """

    alias = "v"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of ``sigmas * pi / 2``."""
        angle_rad = sigmas * pi / 2
        return torch.cos(angle_rad), torch.sin(angle_rad)

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Run the network on ``x_noisy`` with per-item noise levels."""
        per_item_sigmas = to_batch(
            x=sigma,
            xs=sigmas,
            batch_size=x_noisy.shape[0],
            device=x_noisy.device,
        )
        return self.net(x_noisy, per_item_sigmas, **kwargs)

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Return the MSE between the network output and the v target."""
        # One noise level per batch item, reshaped to broadcast over (b, c, t)
        sigmas = self.sigma_distribution(num_samples=x.shape[0], device=x.device)
        sigmas_broadcast = rearrange(sigmas, "b -> b 1 1")
        noise = default(noise, lambda: torch.randn_like(x))
        alpha, beta = self.get_alpha_beta(sigmas_broadcast)
        mixed = x * alpha + noise * beta
        v_target = noise * alpha - x * beta
        v_pred = self.denoise_fn(mixed, sigmas, **kwargs)
        return F.mse_loss(v_pred, v_target)
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364

    The network is wrapped with the preconditioning coefficients from
    ``get_scale_weights``. Training minimises a per-item MSE scaled by
    ``loss_weight``.
    """

    alias = "k"

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution: Distribution,
        sigma_data: float,  # data distribution standard deviation
        dynamic_threshold: float = 0.0,
    ):
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        # NOTE(review): stored but not read by any method of this class.
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return ``(c_skip, c_out, c_in, c_noise)`` for 1-D ``sigmas`` of shape (b,).

        ``c_noise = log(sigma) / 4`` keeps shape (b,). The other three are
        reshaped to (b, 1, 1) so they broadcast against 3-D inputs.
        """
        sigma_data = self.sigma_data
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
        c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return the denoised estimate of ``x_noisy``.

        Exactly one of ``sigmas`` (per item) or ``sigma`` (shared by the
        batch) must be given.
        """
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        # Predict network output and add skip connection
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred
        return x_denoised

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        """Per-item loss weight ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``."""
        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Return the weighted denoising loss for a clean batch ``x``."""
        batch_size, device = x.shape[0], x.device
        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
        # Add noise to input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise
        # Compute denoised values
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
        # Per-item MSE, weighted by noise level, then averaged over the batch
        losses = F.mse_loss(x_denoised, x, reduction="none")
        losses = reduce(losses, "b ... -> b", "mean")
        losses = losses * self.loss_weight(sigmas)
        loss = losses.mean()
        return loss
class VKDiffusion(Diffusion):
    """v-objective diffusion with Karras-style preconditioning (``sigma_data`` fixed at 1.0).

    The network is conditioned on ``t = atan(sigma) * 2 / pi`` rather than on
    ``sigma`` directly.
    """

    alias = "vk"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return ``(c_skip, c_out, c_in)``, each shaped (b, 1, 1); note ``c_out`` is negated."""
        sigma_data = 1.0
        s = rearrange(sigmas, "b -> b 1 1")
        total_var = s ** 2 + sigma_data ** 2
        c_skip = (sigma_data ** 2) / total_var
        c_out = -s * sigma_data * total_var ** -0.5
        c_in = total_var ** -0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
        """Map sigma in [0, inf) to t = atan(sigma) * 2 / pi."""
        return torch.atan(sigmas) / pi * 2

    def t_to_sigma(self, t: Tensor) -> Tensor:
        """Inverse of ``sigma_to_t``: sigma = tan(t * pi / 2)."""
        return torch.tan(t * pi / 2)

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return the denoised estimate ``c_skip * x_noisy + c_out * net(...)``."""
        per_item = to_batch(
            x=sigma, xs=sigmas, batch_size=x_noisy.shape[0], device=x_noisy.device
        )
        c_skip, c_out, c_in = self.get_scale_weights(per_item)
        net_out = self.net(c_in * x_noisy, self.sigma_to_t(per_item), **kwargs)
        return c_skip * x_noisy + c_out * net_out

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Return the MSE between the raw network output and the implied v target."""
        sigmas = self.sigma_distribution(num_samples=x.shape[0], device=x.device)
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + rearrange(sigmas, "b -> b 1 1") * noise
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        net_out = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
        # Solve x = c_skip * x_noisy + c_out * v for v; 1e-7 guards the division
        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
        return F.mse_loss(net_out, v_target)
"""
Diffusion Sampling
"""
""" Schedules """
class Schedule(nn.Module):
    """Interface used by different sampling schedules"""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        raise NotImplementedError()


class LinearSchedule(Schedule):
    """Linearly spaced sigmas from 1 down towards 0 (0 excluded).

    Returns ``num_steps`` values ``1, 1 - 1/num_steps, ..., 1/num_steps``.
    """

    def forward(self, num_steps: int, device: Any) -> Tensor:
        # Build the tensor on the requested device instead of always on CPU.
        sigmas = torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
        return sigmas


class KarrasSchedule(Schedule):
    """https://arxiv.org/abs/2206.00364 equation 5

    Returns ``num_steps + 1`` sigmas. They interpolate from ``sigma_max`` to
    ``sigma_min`` linearly in ``sigma ** (1 / rho)`` space, followed by a
    final 0.
    """

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        rho_inv = 1.0 / self.rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        # With a single step, num_steps - 1 == 0 would make 0 / 0 = NaN;
        # dividing by 1 instead yields just sigma_max.
        denominator = max(num_steps - 1, 1)
        sigmas = (
            self.sigma_max ** rho_inv
            + (steps / denominator)
            * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
        ) ** self.rho
        # Terminal noise level of 0
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas
""" Samplers """
class Sampler(nn.Module):
    """Common interface for diffusion samplers.

    ``diffusion_types`` lists the diffusion classes whose ``alias`` this
    sampler accepts.
    """

    diffusion_types: List[Type[Diffusion]] = []

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Turn ``noise`` into a sample by calling the denoiser ``fn``; subclass hook."""
        raise NotImplementedError()

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Inpaint ``source`` under ``mask``; only some subclasses support this."""
        raise NotImplementedError("Inpainting not available with current sampler")
class VSampler(Sampler):
    """Deterministic sampler for v-objective (``VDiffusion``) models.

    At each noise level the network output ``v`` is split into a clean-signal
    estimate ``x_pred`` and a noise estimate ``x_eps``. These are recombined at
    the next noise level. The clean estimate from the final level is returned.
    """

    diffusion_types = [VDiffusion]

    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
        """Return ``(cos, sin)`` of ``sigma * pi / 2``."""
        angle = sigma * pi / 2
        alpha = cos(angle)
        beta = sin(angle)
        return alpha, beta

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Sample using the first ``num_steps`` entries of ``sigmas``."""
        x = sigmas[0] * noise
        alpha, beta = self.get_alpha_beta(sigmas[0].item())
        # Returned unchanged if the loop body never runs (num_steps <= 0).
        x_pred = x
        # Visit every one of the num_steps noise levels. Stopping at
        # num_steps - 1 skipped the last level and left x_pred unbound when
        # num_steps == 1.
        for i in range(num_steps):
            is_last = i == num_steps - 1
            # Network output is the v prediction at this noise level
            x_denoised = fn(x, sigma=sigmas[i])
            x_pred = x * alpha - x_denoised * beta
            x_eps = x * beta + x_denoised * alpha
            if not is_last:
                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
                x = x_pred * alpha + x_eps * beta
        return x_pred
class KarrasSampler(Sampler):
    """https://arxiv.org/abs/2206.00364 algorithm 1

    Stochastic sampler with a second-order (Heun) correction. Per step it
    optionally raises the noise level by a factor ``1 + gamma`` ("churn"),
    takes an Euler step to ``sigma_next``, and corrects it with the slope at
    ``sigma_next`` (skipped when ``sigma_next == 0``).

    Args:
        s_tmin, s_tmax: churn is applied only to sigmas in ``[s_tmin, s_tmax]``.
        s_churn: total churn; per-step gamma is ``min(s_churn / num_steps, sqrt(2) - 1)``.
        s_noise: scale of the noise injected when raising the noise level.
    """

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(
        self,
        s_tmin: float = 0,
        s_tmax: float = float("inf"),
        s_churn: float = 0.0,
        s_noise: float = 1.0,
    ):
        super().__init__()
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise
        self.s_churn = s_churn

    def step(
        self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
    ) -> Tensor:
        """Algorithm 2 (step)"""
        # Select temporarily increased noise level
        sigma_hat = sigma + gamma * sigma
        # Add noise to move from sigma to sigma_hat
        epsilon = self.s_noise * torch.randn_like(x)
        x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
        # Evaluate ∂x/∂sigma at sigma_hat
        d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
        # Take euler step from sigma_hat to sigma_next
        x_next = x_hat + (sigma_next - sigma_hat) * d
        # Second order correction
        if sigma_next != 0:
            model_out_next = fn(x_next, sigma=sigma_next)
            d_prime = (x_next - model_out_next) / sigma_next
            # Heun: average both slopes over the same interval sigma_hat -> sigma_next.
            # (sigma - sigma_hat) equals -gamma * sigma, which is 0 when gamma == 0
            # and would throw the whole update away.
            x_next = x_hat + 0.5 * (sigma_next - sigma_hat) * (d + d_prime)
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        x = sigmas[0] * noise
        # Compute gammas
        gammas = torch.where(
            (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
            min(self.s_churn / num_steps, sqrt(2) - 1),
            0.0,
        )
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(
                x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i]  # type: ignore # noqa
            )
        return x
class AEulerSampler(Sampler):
    """Ancestral Euler sampler.

    Each step moves deterministically (Euler) from ``sigma`` to ``sigma_down``
    and then adds fresh Gaussian noise of std ``sigma_up``. The two are chosen
    so that ``sigma_down ** 2 + sigma_up ** 2 == sigma_next ** 2``.
    """

    diffusion_types = [KDiffusion, VKDiffusion]

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
        """Return ``(sigma_up, sigma_down)`` for a step from ``sigma`` to ``sigma_next``."""
        next_sq = sigma_next ** 2
        noise_std = sqrt(next_sq * (sigma ** 2 - next_sq) / sigma ** 2)
        target_std = sqrt(next_sq - noise_std ** 2)
        return noise_std, target_std

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Take one ancestral Euler step from ``sigma`` to ``sigma_next``."""
        sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
        # Slope dx/dsigma estimated from the denoiser output at sigma
        slope = (x - fn(x, sigma=sigma)) / sigma
        x_down = x + slope * (sigma_down - sigma)
        return x_down + torch.randn_like(x) * sigma_up

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Denoise ``sigmas[0] * noise`` through ``sigmas[:num_steps]``."""
        x = sigmas[0] * noise
        for idx in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[idx], sigma_next=sigmas[idx + 1])  # type: ignore # noqa
        return x
class ADPM2Sampler(Sampler):
    """Ancestral sampler that uses a midpoint derivative evaluation.

    Each step:
      - evaluates the denoiser at ``sigma``;
      - steps to an intermediate ``sigma_mid`` and evaluates the denoiser there;
      - moves from ``sigma`` to ``sigma_down`` using the midpoint slope;
      - adds fresh Gaussian noise of std ``sigma_up``.

    Reference plot: https://www.desmos.com/calculator/jbxjlqd9mb
    """

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        # Exponent of the power mean used for sigma_mid (rho=1 -> arithmetic mean)
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Return ``(sigma_up, sigma_down, sigma_mid)`` for a step ``sigma -> sigma_next``.

        ``sigma_down ** 2 + sigma_up ** 2 == sigma_next ** 2`` (up to rounding).
        ``sigma_mid`` is ``((sigma ** (1/rho) + sigma_down ** (1/rho)) / 2) ** rho``.

        NOTE(review): assumes ``0 <= sigma_next <= sigma`` and ``sigma != 0``.
        Otherwise ``sqrt`` receives a negative value or the division is by zero.
        """
        r = self.rho
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Take one ancestral midpoint step from ``sigma`` to ``sigma_next``."""
        # Sigma steps
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma (∂x/∂sigma)
        d = (x - fn(x, sigma=sigma)) / sigma
        # Denoise to midpoint
        x_mid = x + d * (sigma_mid - sigma)
        # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        # Denoise to next: full interval sigma -> sigma_down, but with the midpoint slope
        x = x + d_mid * (sigma_down - sigma)
        # Add randomness
        x_next = x + torch.randn_like(x) * sigma_up
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Denoise ``sigmas[0] * noise`` through ``sigmas[:num_steps]``."""
        x = sigmas[0] * noise
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
        return x

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Regenerate ``source`` where ``mask`` is False, keeping it where True.

        At each step the known region is re-noised to the current level and
        merged in. The step is then repeated ``num_resamples`` times,
        re-noising back up to ``sigmas[i]`` between repeats.

        NOTE(review): assumes ``mask`` is a bool tensor broadcastable to
        ``source``; ``~mask`` is used as its complement.
        """
        x = sigmas[0] * torch.randn_like(source)
        for i in range(num_steps - 1):
            # Noise source to current noise level
            source_noisy = source + sigmas[i] * torch.randn_like(source)
            for r in range(num_resamples):
                # Merge noisy source and current then denoise
                x = source_noisy * mask + x * ~mask
                x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
                # Renoise if not last resample step: add back the variance removed
                # by going from sigmas[i] to sigmas[i + 1]
                if r < num_resamples - 1:
                    sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
                    x = x + sigma * torch.randn_like(x)
        return source * mask + x * ~mask
""" Main Classes """
class DiffusionSampler(nn.Module):
    """Draw samples from a diffusion model with a given sampler and sigma schedule.

    Args:
        diffusion: model whose ``denoise_fn`` is called by the sampler.
        sampler: integration scheme. Its ``diffusion_types`` must include a
            class with the same ``alias`` as ``diffusion``.
        sigma_schedule: produces the noise levels for a given step count.
        num_steps: default step count used when ``forward`` gets none.
        clamp: if True, clip the final output to ``[-1, 1]``.
    """

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        sampler: Sampler,
        sigma_schedule: Schedule,
        num_steps: Optional[int] = None,
        clamp: bool = True,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

        # Reject sampler/diffusion pairings the sampler does not list (matched by alias)
        supported_aliases = [t.alias for t in sampler.diffusion_types]
        sampler_class = sampler.__class__.__name__
        diffusion_class = diffusion.__class__.__name__
        message = f"{sampler_class} incompatible with {diffusion_class}"
        assert diffusion.alias in supported_aliases, message

    def forward(
        self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
    ) -> Tensor:
        """Sample from ``noise``; extra ``kwargs`` are forwarded to every denoiser call."""
        num_steps = default(num_steps, self.num_steps)  # type: ignore
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        sigmas = self.sigma_schedule(num_steps, noise.device)

        def fn(*args, **call_kwargs):
            # Caller-level kwargs (e.g. conditioning) override per-call ones
            return self.denoise_fn(*args, **{**call_kwargs, **kwargs})

        sample = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        if self.clamp:
            sample = sample.clamp(-1.0, 1.0)
        return sample
class DiffusionInpainter(nn.Module):
    """Inpaint a tensor with a diffusion model via ``sampler.inpaint``.

    The step count, resample count and sigma schedule are fixed at
    construction. ``forward`` runs under ``torch.no_grad()``.
    """

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        num_steps: int,
        num_resamples: int,
        sampler: Sampler,
        sigma_schedule: Schedule,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.num_steps = num_steps
        self.num_resamples = num_resamples
        self.inpaint_fn = sampler.inpaint
        self.sigma_schedule = sigma_schedule

    @torch.no_grad()
    def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
        """Return ``inpaint`` with the region selected by ``inpaint_mask`` regenerated."""
        schedule = self.sigma_schedule(self.num_steps, inpaint.device)
        return self.inpaint_fn(
            source=inpaint,
            mask=inpaint_mask,
            fn=self.denoise_fn,
            sigmas=schedule,
            num_steps=self.num_steps,
            num_resamples=self.num_resamples,
        )
def sequential_mask(like: Tensor, start: int) -> Tensor:
    """Return a bool mask shaped like ``like``: True before ``start`` on dim 2.

    Positions ``[:, :, :start]`` are True and ``[:, :, start:]`` are False.
    ``like`` must have at least 3 dimensions. If ``start`` is at or past the
    length of dim 2, the mask is all True.
    """
    mask = torch.ones_like(like, dtype=torch.bool)
    # Scalar assignment broadcasts over any trailing dims and handles
    # start >= length, where building a zeros tensor of size length - start
    # would fail.
    mask[:, :, start:] = False
    return mask
class SpanBySpanComposer(nn.Module):
    """Extend a signal one half-span at a time by repeated inpainting.

    In each round, the most recent half-span is placed in the first half of a
    buffer and the inpainter generates the second half. That generated half
    becomes the context for the next round.
    """

    def __init__(
        self,
        inpainter: DiffusionInpainter,
        *,
        num_spans: int,
    ):
        super().__init__()
        self.inpainter = inpainter
        self.num_spans = num_spans

    def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
        """Return ``num_spans`` generated halves concatenated on dim 2.

        The two halves of ``start`` are prepended when ``keep_start`` is True.
        """
        half_length = start.shape[2] // 2
        spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []

        # Context buffer: first half is known, second half is to be generated
        context = torch.zeros_like(start)
        context[:, :, :half_length] = start[:, :, half_length:]
        known_mask = sequential_mask(like=start, start=half_length)

        for _ in range(self.num_spans):
            generated = self.inpainter(inpaint=context, inpaint_mask=known_mask)
            new_half = generated[:, :, half_length:]
            # Slide: the freshly generated half is the next round's context
            context[:, :, :half_length] = new_half
            spans.append(new_half)

        return torch.cat(spans, dim=2)
class XDiffusion(nn.Module):
    """Build the diffusion class whose ``alias`` equals ``type`` and delegate to it.

    Args:
        type: one of the aliases of ``VDiffusion``, ``KDiffusion``, ``VKDiffusion``.
        net: denoising network passed to the chosen diffusion class.
        **kwargs: remaining constructor arguments for the chosen class.
    """

    def __init__(self, type: str, net: nn.Module, **kwargs):
        super().__init__()
        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
        class_by_alias = {cls.alias: cls for cls in diffusion_classes}  # type: ignore
        aliases = list(class_by_alias)
        message = f"type='{type}' must be one of {*aliases,}"
        assert type in aliases, message
        self.net = net
        self.diffusion = class_by_alias[type](net=net, **kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        """Return the wrapped diffusion model's training loss."""
        return self.diffusion(*args, **kwargs)

    def sample(
        self,
        noise: Tensor,
        num_steps: int,
        sigma_schedule: Schedule,
        sampler: Sampler,
        clamp: bool,
        **kwargs,
    ) -> Tensor:
        """Sample from ``noise``; ``kwargs`` are forwarded to every denoiser call."""
        return DiffusionSampler(
            diffusion=self.diffusion,
            sampler=sampler,
            sigma_schedule=sigma_schedule,
            num_steps=num_steps,
            clamp=clamp,
        )(noise, **kwargs)