# mueller-franzes's picture
# history blame
# 6.94 kB
import torch
import torch.nn.functional as F
from medical_diffusion.models.noise_schedulers import BasicNoiseScheduler
class GaussianNoiseScheduler(BasicNoiseScheduler):
    """Noise scheduler for Gaussian (DDPM-style) diffusion.

    Builds a beta schedule, derives the cumulative-product coefficients from
    it, and registers all of them as buffers. The remaining methods use those
    buffers to move between ``x_0`` (clean sample), ``x_t`` (sample at step
    ``t``) and ``x_T`` (noise).

    ``self.extract`` and ``self.T`` are not defined in this class; they are
    presumably provided by ``BasicNoiseScheduler`` -- TODO confirm. ``extract``
    is called as ``extract(buffer, t, ndim)`` and is assumed to return the
    buffer values at ``t`` in a shape that broadcasts against an ``ndim``-
    dimensional batch.
    """

    # NOTE(review): the extracted source lost `self`, `timesteps` and
    # `schedule_strategy` from this signature (all three are used in the body).
    # Their position and the defaults below are reconstructed -- confirm
    # against the upstream file before relying on them.
    def __init__(
        self,
        timesteps=1000,
        T=None,
        schedule_strategy='cosine',
        beta_start=0.0001,  # default 1e-4, stable-diffusion ~ 1e-3
        beta_end=0.02,
        betas=None,
    ):
        """Compute the noise schedule and register the derived coefficients.

        Args:
            timesteps: Number of entries in the beta schedule.
            T: Forwarded unchanged to ``BasicNoiseScheduler.__init__``.
            schedule_strategy: One of ``"linear"``, ``"scaled_linear"`` or
                ``"cosine"``. Ignored for the schedule itself when ``betas``
                is given, but still stored on the instance.
            beta_start: First beta of the linear / scaled-linear schedules.
            beta_end: Last beta of the linear / scaled-linear schedules.
            betas: Optional explicit schedule; takes precedence over
                ``schedule_strategy``.

        Raises:
            NotImplementedError: If ``betas`` is None and
                ``schedule_strategy`` is not one of the supported names.
        """
        super().__init__(timesteps, T)

        self.schedule_strategy = schedule_strategy

        # The schedule is computed in float64 to limit rounding error in the
        # cumulative products below.
        if betas is not None:
            betas = torch.as_tensor(betas, dtype=torch.float64)
        elif schedule_strategy == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
        elif schedule_strategy == "scaled_linear":
            # Linear in sqrt(beta), then squared (sometimes called
            # "quadratic"); used in stable-diffusion.
            betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float64)**2
        elif schedule_strategy == "cosine":
            s = 0.008  # small offset inside the cosine argument
            x = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)  # [0, T]
            alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            betas = torch.clip(betas, 0, 0.999)
        else:
            # NOTE(review): the `else:` was missing in the extracted source,
            # which made this raise unconditional.
            raise NotImplementedError(f"{schedule_strategy} does is not implemented for {self.__class__}")

        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 for the
        # (non-existent) step before the first one.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        schedule_buffers = {
            'betas': betas,  # (0 , 1)
            'alphas': alphas,  # (1 , 0)
            'alphas_cumprod': alphas_cumprod,
            'alphas_cumprod_prev': alphas_cumprod_prev,
            'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
            'sqrt_one_minus_alphas_cumprod': torch.sqrt(1. - alphas_cumprod),
            'sqrt_recip_alphas_cumprod': torch.sqrt(1. / alphas_cumprod),
            'sqrt_recipm1_alphas_cumprod': torch.sqrt(1. / alphas_cumprod - 1),
            'posterior_mean_coef1': betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod),
            'posterior_mean_coef2': (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod),
            'posterior_variance': betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod),
        }
        # NOTE(review): the second argument of the original registration
        # helper was cut off in the extracted source. Casting the float64
        # schedule to float32 for storage is assumed here -- confirm against
        # the upstream file.
        for name, val in schedule_buffers.items():
            self.register_buffer(name, val.to(torch.float32))

    def estimate_x_t(self, x_0, t, x_T=None):
        """Diffuse ``x_0`` to step ``t`` (forward process), per batch element.

        Args:
            x_0: Batch of clean samples, indexed along its first dimension.
            t: Tensor of step indices, one per batch element. Values below 0
                return ``x_0`` unchanged; values ``>= self.T`` return ``x_T``.
            x_T: Noise to mix in. Drawn with ``x_final(x_0)`` when None.

        Returns:
            Tensor stacking, per batch element,
            ``sqrt_alphas_cumprod[t] * x_0 + sqrt_one_minus_alphas_cumprod[t] * x_T``
            (or the clipped value described above).
        """
        # NOTE: t == 0 means diffused for 1 step in the standard DDPM convention.
        # NOTE: t == 0 means not diffused for cold-diffusion (in contradiction to the above comment)
        x_T = self.x_final(x_0) if x_T is None else x_T

        # Unclipped, vectorised equivalent:
        # ndim = x_0.ndim
        # x_t = (self.extract(self.sqrt_alphas_cumprod, t, ndim)*x_0 +
        #        self.extract(self.sqrt_one_minus_alphas_cumprod, t, ndim)*x_T)

        def clipper(b):
            """Return the diffused sample for batch index ``b``, clipping ``t``."""
            tb = t[b]
            if tb < 0:
                return x_0[b]
            elif tb >= self.T:
                return x_T[b]
            return self.sqrt_alphas_cumprod[tb]*x_0[b] + self.sqrt_one_minus_alphas_cumprod[tb]*x_T[b]

        x_t = torch.stack([clipper(b) for b in range(t.shape[0])])
        return x_t

    def estimate_x_t_prior_from_x_T(self, x_t, t, x_T, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
        """Estimate the previous-step sample from ``x_t`` and the noise ``x_T``.

        First reconstructs ``x_0`` via ``estimate_x_0`` and then delegates to
        ``estimate_x_t_prior_from_x_0`` with the same options.

        Returns:
            Tuple ``(x_t_prior, x_0)``.
        """
        x_0 = self.estimate_x_0(x_t, x_T, t, clip_x0)
        return self.estimate_x_t_prior_from_x_0(x_t, t, x_0, use_log, clip_x0, var_scale, cold_diffusion)

    def estimate_x_t_prior_from_x_0(self, x_t, t, x_0, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
        """Estimate the previous-step sample from ``x_t`` and an ``x_0`` estimate.

        Args:
            x_t: Sample at step ``t``.
            t: Tensor of step indices, one per batch element.
            x_0: Estimate of the clean sample.
            use_log: Passed to ``estimate_variance_t``; when True the returned
                value is treated as a log-variance.
            clip_x0: Clip ``x_0`` with ``_clip_x_0`` before use.
            var_scale: Interpolation weight passed to ``estimate_variance_t``.
            cold_diffusion: Use the deterministic cold-diffusion update
                instead of sampling from the Gaussian posterior.

        Returns:
            Tuple ``(x_t_prior, x_0)`` where ``x_0`` is the (possibly clipped)
            input estimate.
        """
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0

        if cold_diffusion:  # cold-diffusion style update
            x_T_est = self.estimate_x_T(x_t, x_0, t)  # or use x_T estimated by UNet if available?
            x_t_est = self.estimate_x_t(x_0, t, x_T=x_T_est)
            # t-1 may be -1; estimate_x_t then returns x_0 for that element.
            x_t_prior = self.estimate_x_t(x_0, t-1, x_T=x_T_est)
            noise_t = x_t_est - x_t_prior
            x_t_prior = x_t - noise_t
        else:
            # NOTE(review): the `else:` was missing in the extracted source,
            # so this branch always overwrote the cold-diffusion result.
            mean = self.estimate_mean_t(x_t, x_0, t)
            variance = self.estimate_variance_t(t, x_t.ndim, use_log, var_scale)
            std = torch.exp(0.5*variance) if use_log else torch.sqrt(variance)
            std[t == 0] = 0.0  # no noise is added for elements at t == 0
            x_T = self.x_final(x_t)
            x_t_prior = mean + std*x_T
        return x_t_prior, x_0

    def estimate_mean_t(self, x_t, x_0, t):
        """Return ``posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t``."""
        ndim = x_t.ndim
        return (self.extract(self.posterior_mean_coef1, t, ndim)*x_0 +
                self.extract(self.posterior_mean_coef2, t, ndim)*x_t)

    def estimate_variance_t(self, t, ndim, log=True, var_scale=0, eps=1e-20):
        """Interpolate between the posterior variance and beta at step ``t``.

        Args:
            t: Tensor of step indices.
            ndim: Number of dimensions passed through to ``self.extract``.
            log: Interpolate the logarithms of both bounds (each clamped to
                at least ``eps`` first); the result is then a log-variance.
            var_scale: Weight of the upper bound (``betas``); ``0`` returns
                the lower bound (``posterior_variance``) only.
            eps: Lower clamp applied before taking the logarithm.

        Returns:
            ``var_scale * max_variance + (1 - var_scale) * min_variance``.
        """
        min_variance = self.extract(self.posterior_variance, t, ndim)
        max_variance = self.extract(self.betas, t, ndim)
        if log:
            min_variance = torch.log(min_variance.clamp(min=eps))
            max_variance = torch.log(max_variance.clamp(min=eps))
        return var_scale * max_variance + (1 - var_scale) * min_variance

    def estimate_x_0(self, x_t, x_T, t, clip_x0=True):
        """Reconstruct ``x_0`` from ``x_t`` and the noise ``x_T``.

        Computes ``sqrt_recip_alphas_cumprod[t] * x_t -
        sqrt_recipm1_alphas_cumprod[t] * x_T`` and optionally clips the result
        with ``_clip_x_0``.
        """
        ndim = x_t.ndim
        x_0 = (self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t -
               self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim)*x_T)
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
        return x_0

    def estimate_x_T(self, x_t, x_0, t, clip_x0=True):
        """Reconstruct the noise ``x_T`` from ``x_t`` and ``x_0``.

        Inverse of ``estimate_x_0`` with respect to ``x_T``:
        ``(sqrt_recip_alphas_cumprod[t] * x_t - x_0) /
        sqrt_recipm1_alphas_cumprod[t]``. ``x_0`` is clipped first when
        ``clip_x0`` is True.
        """
        ndim = x_t.ndim
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
        return ((self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t - x_0) /
                self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim))

    # NOTE(review): the first parameter is named `cls` but the extracted
    # source carried no decorator; `@classmethod` is restored so that both
    # instance-level and class-level calls work.
    @classmethod
    def x_final(cls, x):
        """Return standard-normal noise with the same shape/dtype/device as ``x``."""
        return torch.randn_like(x)

    @classmethod
    def _clip_x_0(cls, x_0):
        """Clamp ``x_0`` to ``[-m, m]`` with ``m = 1`` (static thresholding)."""
        # See "static/dynamic thresholding" in Imagen
        # "static thresholding"
        m = 1  # Set this to about 4*sigma = 4 if latent diffusion is used
        x_0 = x_0.clamp(-m, m)

        # "dynamic thresholding"
        # r = torch.stack([torch.quantile(torch.abs(x_0_b), 0.997) for x_0_b in x_0])
        # r = torch.maximum(r, torch.full_like(r,m))
        # x_0 = torch.stack([x_0_b.clamp(-r_b, r_b)/r_b*m for x_0_b, r_b in zip(x_0, r) ] )

        return x_0