|
from functools import partial |
|
from typing import Tuple |
|
|
|
import torch |
|
from torch import nn |
|
import numpy as np |
|
|
|
|
|
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build the per-step noise variances (betas) of a diffusion process.

    Args:
        schedule: One of ``"linear"``, ``"cosine"``, ``"sqrt_linear"`` or ``"sqrt"``.
        n_timestep: Number of diffusion steps; the length of the returned array.
        linear_start: First value of the linear-family schedules.
        linear_end: Last value of the linear-family schedules.
        cosine_s: Offset ``s`` added to the normalized time in the cosine schedule.

    Returns:
        A float64 ``np.ndarray`` of shape ``(n_timestep,)``.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, then squared back to beta.
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2

    elif schedule == "cosine":
        timesteps = np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        # NumPy arrays have no ``.pow`` method (that is the torch API); square with ``**``.
        alphas = np.cos(alphas) ** 2
        # Normalize so the cumulative alpha at t=0 is exactly 1.
        alphas = alphas / alphas[0]
        # Recover per-step betas from ratios of consecutive cumulative alphas.
        betas = 1 - alphas[1:] / alphas[:-1]
        # Cap betas below 1 so the final steps stay numerically well-defined.
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)

    elif schedule == "sqrt":
        betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5

    else:
        raise ValueError(f"schedule '{schedule}' unknown.")

    return betas
|
|
|
|
|
def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor:
    """Pick entries of ``a`` at indices ``t`` and shape them to broadcast against ``x_shape``.

    The values come from ``a.gather(-1, t)`` and are reshaped to
    ``(batch, 1, ..., 1)``. Here ``batch`` is the leading size of ``t``, and the
    result has ``len(x_shape) - 1`` trailing singleton dimensions.
    """
    batch_size, *_ignored = t.shape
    picked = a.gather(-1, t)
    broadcast_shape = (batch_size,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape)
|
|
|
|
|
class Diffusion(nn.Module):
    """Forward (noising) process and simple training loss for a diffusion model.

    Precomputes the beta schedule and the cumulative-product coefficients used by
    ``q_sample``. ``p_losses`` noises a clean batch, runs a denoising ``model`` on
    it, and compares the output to the target selected by ``parameterization``.
    """

    def __init__(
        self,
        timesteps=1000,
        beta_schedule="linear",
        loss_type="l2",
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        parameterization="eps",
    ):
        """
        Args:
            timesteps: Number of diffusion steps.
            beta_schedule: Schedule name forwarded to ``make_beta_schedule``.
            loss_type: ``"l1"`` or ``"l2"``; only checked when ``get_loss`` runs.
            linear_start: Schedule parameter forwarded to ``make_beta_schedule``.
            linear_end: Schedule parameter forwarded to ``make_beta_schedule``.
            cosine_s: Schedule parameter forwarded to ``make_beta_schedule``.
            parameterization: Training target of the model: ``"eps"`` (the noise),
                ``"x0"`` (the clean input) or ``"v"`` (see ``get_v``).
        """
        super().__init__()
        self.num_timesteps = timesteps
        self.beta_schedule = beta_schedule
        self.linear_start = linear_start
        self.linear_end = linear_end
        self.cosine_s = cosine_s
        assert parameterization in ["eps", "x0", "v"], "currently only supporting 'eps' and 'x0' and 'v'"
        self.parameterization = parameterization
        self.loss_type = loss_type

        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

        # Kept as the raw NumPy array (a plain attribute, not a registered buffer).
        self.betas = betas
        self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod)
        self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod)

    def register(self, name: str, value: np.ndarray) -> None:
        """Store ``value`` as a float32 tensor buffer called ``name`` on this module."""
        self.register_buffer(name, torch.tensor(value, dtype=torch.float32))

    def q_sample(self, x_start, t, noise):
        """Noise ``x_start`` to step ``t``.

        Returns ``sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise``.
        ``t`` is used as a ``gather`` index into 1-D buffers, so it must be a 1-D
        integer (int64) tensor with one step per batch element.
        """
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def get_v(self, x, noise, t):
        """Return the "v" target ``sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * x``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def get_loss(self, pred, target, mean=True):
        """Compute the element-wise L1 or L2 loss between ``pred`` and ``target``.

        Args:
            pred: Model output.
            target: Regression target with the same shape as ``pred``.
            mean: If True, reduce to a scalar mean; otherwise return per-element losses.

        Raises:
            NotImplementedError: If ``self.loss_type`` is neither ``"l1"`` nor ``"l2"``.
        """
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            reduction = 'mean' if mean else 'none'
            loss = torch.nn.functional.mse_loss(target, pred, reduction=reduction)
        else:
            # The original message lacked the f-prefix and named an undefined local,
            # so the offending value never appeared in the error.
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, model, x_start, t, cond):
        """Return the scalar training loss for denoising ``x_start`` at steps ``t``.

        Samples Gaussian noise and forms ``x_noisy`` via ``q_sample``. Then calls
        ``model(x_noisy, t, cond)`` and averages the element-wise loss against the
        target chosen by ``self.parameterization``.
        """
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = model(x_noisy, t, cond)

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        elif self.parameterization == "v":
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean()
        return loss_simple
|
|