Weiyu Liu
add demo
8c02843
raw
history blame
2.76 kB
import math
import torch
import torch.nn.functional as F
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule as proposed in https://arxiv.org/abs/2102.09672
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
class NoiseSchedule:
def __init__(self, timesteps=200):
self.timesteps = timesteps
# define beta schedule
self.betas = linear_beta_schedule(timesteps=timesteps)
# self.betas = cosine_beta_schedule(timesteps=timesteps)
# define alphas
self.alphas = 1. - self.betas
# alphas_cumprod: alpha bar
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
# forward diffusion (using the nice property)
def q_sample(x_start, t, noise_schedule, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(noise_schedule.sqrt_alphas_cumprod, t, x_start.shape)
# print("sqrt_alphas_cumprod_t", sqrt_alphas_cumprod_t)
sqrt_one_minus_alphas_cumprod_t = extract(
noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
# print("sqrt_one_minus_alphas_cumprod_t", sqrt_one_minus_alphas_cumprod_t)
# print("noise", noise)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise