# NOTE(review): removed non-Python scrape residue that preceded the code
# (HuggingFace Spaces page header: "Spaces:", "Runtime error" status lines,
# file size, commit hash, and a run of gutter line numbers). It made the
# file unparseable as Python and carried no source content.
import torch
import torch.nn.functional as F
from medical_diffusion.models.noise_schedulers import BasicNoiseScheduler
class GaussianNoiseScheduler(BasicNoiseScheduler):
    """Gaussian (DDPM-style) diffusion noise scheduler.

    Precomputes the beta/alpha schedule and all derived quantities
    (cumulative products, posterior coefficients) once at construction in
    float64 for accuracy, then registers them as float32 buffers so they
    follow the module across devices/dtypes. Supports 'linear',
    'scaled_linear' (quadratic) and 'cosine' schedules, or an explicit
    user-supplied ``betas`` tensor which overrides ``schedule_strategy``.

    Tensor arguments to the estimate_* methods are assumed to be batched
    image-like tensors (batch first) with a per-sample integer timestep
    tensor ``t`` of shape (B,) — TODO confirm against callers.
    """

    def __init__(
        self,
        timesteps=1000,
        T=None,
        schedule_strategy='cosine',
        beta_start=0.0001,  # default 1e-4, stable-diffusion ~ 1e-3
        beta_end=0.02,
        betas=None,  # optional explicit schedule; takes precedence over schedule_strategy
    ):
        super().__init__(timesteps, T)
        self.schedule_strategy = schedule_strategy

        # Build the beta schedule in float64; buffers are downcast to
        # float32 when registered below.
        if betas is not None:
            betas = torch.as_tensor(betas, dtype=torch.float64)
        elif schedule_strategy == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
        elif schedule_strategy == "scaled_linear":  # proposed as "quadratic" in https://arxiv.org/abs/2006.11239, used in stable-diffusion
            betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float64)**2
        elif schedule_strategy == "cosine":
            # Cosine schedule (Improved DDPM, https://arxiv.org/abs/2102.09672)
            s = 0.008  # small offset keeps betas away from 0 near t=0
            x = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)  # [0, T]
            alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            # Clip to avoid singularities at the end of the schedule.
            betas = torch.clip(betas, 0, 0.999)
        else:
            # BUGFIX: message previously read "does is not implemented".
            raise NotImplementedError(f"{schedule_strategy} is not implemented for {self.__class__}")

        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shift right by one so that alpha_bar_{t-1} is defined at t=0 (:= 1).
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)    # (0, 1)
        register_buffer('alphas', alphas)  # (1, 0)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
        # Coefficients of the closed-form posterior q(x_{t-1} | x_t, x_0).
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
        register_buffer('posterior_variance', betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod))

    def estimate_x_t(self, x_0, t, x_T=None):
        """Diffuse ``x_0`` forward to per-sample timestep ``t``: sample of q(x_t | x_0).

        ``t`` is clamped per sample: t < 0 returns x_0 unchanged, t >= T
        returns pure noise x_T. ``x_T`` defaults to fresh Gaussian noise.
        """
        # NOTE: t == 0 means diffused for 1 step (https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils.py#L108)
        # NOTE: t == 0 means not diffused for cold-diffusion (in contradiction to the above comment) https://github.com/arpitbansal297/Cold-Diffusion-Models/blob/c828140b7047ca22f995b99fbcda360bc30fc25d/denoising-diffusion-pytorch/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L361
        x_T = self.x_final(x_0) if x_T is None else x_T
        # ndim = x_0.ndim
        # x_t = (self.extract(self.sqrt_alphas_cumprod, t, ndim)*x_0 +
        #        self.extract(self.sqrt_one_minus_alphas_cumprod, t, ndim)*x_T)

        def clipper(b):
            # Per-sample handling so out-of-range timesteps clamp cleanly.
            tb = t[b]
            if tb < 0:
                return x_0[b]
            elif tb >= self.T:
                return x_T[b]
            else:
                return self.sqrt_alphas_cumprod[tb]*x_0[b] + self.sqrt_one_minus_alphas_cumprod[tb]*x_T[b]
        x_t = torch.stack([clipper(b) for b in range(t.shape[0])])
        return x_t

    def estimate_x_t_prior_from_x_T(self, x_t, t, x_T, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
        """Reverse step given the model's noise estimate ``x_T``: first
        recovers x_0, then delegates to :meth:`estimate_x_t_prior_from_x_0`.

        Returns a tuple ``(x_{t-1}, x_0_estimate)``.
        """
        x_0 = self.estimate_x_0(x_t, x_T, t, clip_x0)
        return self.estimate_x_t_prior_from_x_0(x_t, t, x_0, use_log, clip_x0, var_scale, cold_diffusion)

    def estimate_x_t_prior_from_x_0(self, x_t, t, x_0, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
        """Compute x_{t-1} from x_t and an estimate of x_0.

        With ``cold_diffusion`` the deterministic update of
        https://arxiv.org/abs/2208.09392 is used; otherwise the Gaussian
        posterior mean/variance. ``var_scale`` interpolates between the
        posterior variance (0) and beta_t (1). Returns ``(x_{t-1}, x_0)``.
        """
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0

        if cold_diffusion:  # see https://arxiv.org/abs/2208.09392
            x_T_est = self.estimate_x_T(x_t, x_0, t)  # or use x_T estimated by UNet if available?
            x_t_est = self.estimate_x_t(x_0, t, x_T=x_T_est)
            x_t_prior = self.estimate_x_t(x_0, t-1, x_T=x_T_est)
            noise_t = x_t_est - x_t_prior
            x_t_prior = x_t - noise_t
        else:
            mean = self.estimate_mean_t(x_t, x_0, t)
            variance = self.estimate_variance_t(t, x_t.ndim, use_log, var_scale)
            std = torch.exp(0.5*variance) if use_log else torch.sqrt(variance)
            std[t == 0] = 0.0  # final step is deterministic: no noise added at t=0
            x_T = self.x_final(x_t)
            x_t_prior = mean + std*x_T
        return x_t_prior, x_0

    def estimate_mean_t(self, x_t, x_0, t):
        """Posterior mean of q(x_{t-1} | x_t, x_0)."""
        ndim = x_t.ndim
        return (self.extract(self.posterior_mean_coef1, t, ndim)*x_0 +
                self.extract(self.posterior_mean_coef2, t, ndim)*x_t)

    def estimate_variance_t(self, t, ndim, log=True, var_scale=0, eps=1e-20):
        """Interpolated (log-)variance between the posterior variance
        (lower bound, ``var_scale=0``) and beta_t (upper bound, ``var_scale=1``).

        ``eps`` guards the log against zero variance at t=0.
        """
        min_variance = self.extract(self.posterior_variance, t, ndim)
        max_variance = self.extract(self.betas, t, ndim)
        if log:
            min_variance = torch.log(min_variance.clamp(min=eps))
            max_variance = torch.log(max_variance.clamp(min=eps))
        return var_scale * max_variance + (1 - var_scale) * min_variance

    def estimate_x_0(self, x_t, x_T, t, clip_x0=True):
        """Recover x_0 from x_t and the noise estimate x_T (inverts the
        forward diffusion), optionally clipping the result."""
        ndim = x_t.ndim
        x_0 = (self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t -
               self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim)*x_T)
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
        return x_0

    def estimate_x_T(self, x_t, x_0, t, clip_x0=True):
        """Recover the noise x_T from x_t and x_0 (inverse of estimate_x_0)."""
        ndim = x_t.ndim
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
        return ((self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t - x_0) /
                self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim))

    @classmethod
    def x_final(cls, x):
        """Terminal distribution of the forward process: standard Gaussian
        noise with the same shape/dtype/device as ``x``."""
        return torch.randn_like(x)

    @classmethod
    def _clip_x_0(cls, x_0):
        """Static thresholding of the x_0 estimate to [-m, m].

        See "static/dynamic thresholding" in Imagen
        https://arxiv.org/abs/2205.11487.
        """
        # "static thresholding"
        m = 1  # Set this to about 4*sigma = 4 if latent diffusion is used
        x_0 = x_0.clamp(-m, m)
        # "dynamic thresholding" (disabled)
        # r = torch.stack([torch.quantile(torch.abs(x_0_b), 0.997) for x_0_b in x_0])
        # r = torch.maximum(r, torch.full_like(r,m))
        # x_0 = torch.stack([x_0_b.clamp(-r_b, r_b)/r_b*m for x_0_b, r_b in zip(x_0, r) ] )
        return x_0