import math

import torch
import torch.nn.functional as F


def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    """Linearly increasing betas, as in the original DDPM paper."""
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps):
    """Betas growing quadratically from beta_start to beta_end."""
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps):
    """Betas following a sigmoid ramp between beta_start and beta_end."""
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
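

# Illustrative sketch (not part of the original module): the schedules above differ
# in how quickly alpha_bar_t = prod_{s<=t} (1 - beta_s) decays, i.e. how fast the
# signal is destroyed. The helper name `_compare_schedules` is hypothetical.
def _compare_schedules(timesteps=200):
    schedules = {
        "linear": linear_beta_schedule(timesteps),
        "cosine": cosine_beta_schedule(timesteps),
        "quadratic": quadratic_beta_schedule(timesteps),
        "sigmoid": sigmoid_beta_schedule(timesteps),
    }
    for name, betas in schedules.items():
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        # fraction of the original signal variance remaining at the final step
        print(f"{name:>9}: alpha_bar_T = {alphas_cumprod[-1].item():.6f}")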
class NoiseSchedule:
    def __init__(self, timesteps=200):
        self.timesteps = timesteps

        # define beta schedule
        self.betas = linear_beta_schedule(timesteps=timesteps)
        # self.betas = cosine_beta_schedule(timesteps=timesteps)

        # define alphas
        self.alphas = 1.0 - self.betas
        # alphas_cumprod: alpha bar (cumulative product of the alphas)
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
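

# Quick shape-check sketch (added for illustration; `_check_noise_schedule` is not
# part of the original code). Every buffer on NoiseSchedule is a 1-D tensor of
# length `timesteps`, which is what `extract` below relies on.
def _check_noise_schedule(timesteps=200):
    ns = NoiseSchedule(timesteps=timesteps)
    for name in (
        "betas",
        "alphas",
        "alphas_cumprod",
        "alphas_cumprod_prev",
        "sqrt_recip_alphas",
        "sqrt_alphas_cumprod",
        "sqrt_one_minus_alphas_cumprod",
        "posterior_variance",
    ):
        assert getattr(ns, name).shape == (timesteps,), name
    return ns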
def extract(a, t, x_shape):
    """Gather the schedule values in `a` at the timesteps `t` (one per batch element)
    and reshape them to (batch_size, 1, ..., 1) so they broadcast over `x_shape`."""
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
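

# Illustration of the broadcasting trick in `extract` (hypothetical example, not in
# the original file): for a batch of images of shape (B, C, H, W), extract returns a
# (B, 1, 1, 1) tensor, so each sample gets its own per-timestep coefficient.
def _extract_example():
    ns = NoiseSchedule(timesteps=200)
    t = torch.tensor([0, 50, 199])     # one timestep per batch element
    x = torch.zeros(3, 3, 32, 32)      # dummy batch, only the shape matters
    coeff = extract(ns.sqrt_alphas_cumprod, t, x.shape)
    print(coeff.shape)                 # torch.Size([3, 1, 1, 1])
    return coeff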
# forward diffusion (using the nice property):
# q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
# so x_t can be sampled directly from x_0 in a single step.
def q_sample(x_start, t, noise_schedule, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(noise_schedule.sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
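

# Minimal end-to-end sketch (assumes nothing beyond the definitions above): apply the
# forward process to a random "image" batch at a few timesteps and report how much of
# the original signal remains.
if __name__ == "__main__":
    schedule = NoiseSchedule(timesteps=200)
    x_start = torch.randn(4, 3, 32, 32)
    for step in (0, 50, 100, 199):
        t = torch.full((x_start.shape[0],), step, dtype=torch.long)
        x_noisy = q_sample(x_start, t, schedule)
        signal = extract(schedule.sqrt_alphas_cumprod, t, x_start.shape)[0].item()
        print(f"t={step:3d}  sqrt(alpha_bar_t)={signal:.4f}  x_t std={x_noisy.std().item():.4f}")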