Spaces:
Running
Running
File size: 7,402 Bytes
3aa4060 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import torch
class NoiseScheduleVP:
    """Linear noise schedule plus a reverse sampler stepping in log-SNR space.

    The instantaneous noise level is ``beta_min + (beta_max - beta_min) * t``
    (see :meth:`get_noise`).  The ``marginal_*`` helpers are closed-form
    functions of that schedule and accept tensors for ``t``.
    """

    def __init__(self, beta_min=0.05, beta_max=20):
        """Store the schedule end points and the start time of the sampler."""
        self.beta_min = beta_min
        self.beta_max = beta_max
        # Time at which reverse_diffusion starts integrating.
        self.T = 1.

    def get_noise(self, t, beta_init, beta_term, cumulative=False):
        """Return the linear noise level at ``t`` or its integral over [0, t].

        Args:
            t: Time value (number or tensor).
            beta_init: Noise level at ``t = 0``.
            beta_term: Noise level at ``t = 1``.
            cumulative: When true, return
                ``beta_init*t + 0.5*(beta_term - beta_init)*t**2`` instead of
                the instantaneous value.
        """
        if cumulative:
            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
        else:
            noise = beta_init + (beta_term - beta_init)*t
        return noise

    def marginal_log_mean_coeff(self, t):
        """Return the log of the mean coefficient at time ``t``.

        Algebraically this is ``-0.5`` times the cumulative noise of
        :meth:`get_noise` evaluated with ``beta_min`` / ``beta_max``.
        """
        return (-0.25 * t**2 * (self.beta_max - self.beta_min)
                - 0.5 * t * self.beta_min)

    def marginal_std(self, t):
        """Return ``sqrt(1 - exp(2 * marginal_log_mean_coeff(t)))``."""
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """Return ``log(mean_coeff) - log(std)`` at time ``t``."""
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """Map a value of :meth:`marginal_lambda` back to its time ``t``.

        ``logaddexp(-2*lamb, 0)`` equals ``-2 * marginal_log_mean_coeff(t)``,
        which is quadratic in ``t``; the return expression is the positive
        root of that quadratic, written in a form that avoids subtracting
        two nearly equal numbers.
        """
        beta_range = self.beta_max - self.beta_min
        # The zero operand is created with the dtype/device of ``lamb``.
        tmp = 2. * beta_range * torch.logaddexp(
            -2. * lamb, torch.zeros((1, )).to(lamb))
        Delta = self.beta_min**2 + tmp
        return tmp / (torch.sqrt(Delta) + self.beta_min) / beta_range

    def get_time_steps(self, t_T, t_0, N):
        """Return ``N + 1`` times from ``t_T`` to ``t_0``, uniform in lambda."""
        lambda_T = self.marginal_lambda(torch.tensor(t_T))
        lambda_0 = self.marginal_lambda(torch.tensor(t_0))
        logSNR_steps = torch.linspace(lambda_T, lambda_0, N + 1)
        return self.inverse_lambda(logSNR_steps)

    @torch.no_grad()
    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc):
        """Integrate from ``t = self.T`` down to ``t = 1e-3`` in ``n_timesteps`` steps.

        Args:
            estimator: Callable invoked as ``estimator(spk, x, mask, mu, s)``
                where ``s`` has shape ``(batch,)``; its output is combined
                element-wise with the state, so it presumably matches the
                shape of ``z`` -- confirm against the caller.
            spk: Passed through to ``estimator`` untouched.
            z: Initial state; the first axis is treated as the batch axis.
            mask: Multiplied into ``z`` once before the loop and forwarded
                to ``estimator``.
            mu: Offset removed from the state before stepping and added back
                to the result.
            n_timesteps: Number of steps.
            stoc: Not read by this sampler.

        Returns:
            The final state with ``mu`` added back.
        """
        print("use dpm-solver reverse")
        xt = z * mask
        yt = xt - mu
        eps = 1e-3
        time = self.get_time_steps(self.T, eps, n_timesteps)

        batch = xt.shape[0]
        # Per-sample coefficients come out as (batch,).  They must be lifted
        # to (batch, 1, ..., 1) before touching the state; otherwise they
        # broadcast against the LAST axis, which is only right for batch 1.
        coeff_shape = (batch,) + (1,) * (xt.dim() - 1)

        for i in range(n_timesteps):
            s = torch.ones(batch, device=xt.device) * time[i]
            t = torch.ones(batch, device=xt.device) * time[i + 1]
            lambda_s = self.marginal_lambda(s)
            lambda_t = self.marginal_lambda(t)
            h = lambda_t - lambda_s
            log_alpha_s = self.marginal_log_mean_coeff(s)
            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            phi_1 = torch.expm1(h)

            noise_s = estimator(spk, yt + mu, mask, mu, s)

            # 1 - exp(-cumulative noise at s); its square root scales the
            # estimator output below.
            lt = 1 - torch.exp(-self.get_noise(
                s, self.beta_min, self.beta_max, cumulative=True))
            a = torch.exp(log_alpha_t - log_alpha_s).reshape(coeff_shape)
            b = (sigma_t * phi_1 * torch.sqrt(lt)).reshape(coeff_shape)
            yt = a * yt + (b * noise_s)

        xt = yt + mu
        return xt
class MaxLikelihood:
    """Reverse sampler whose step coefficients are built from ``get_gamma``.

    The noise schedule is linear between ``beta_min`` and ``beta_max``.
    All coefficient helpers call ``torch.exp`` on expressions of ``s`` and
    ``t``, so at least one of the two must be a tensor.
    """

    def __init__(self, beta_min=0.05, beta_max=20):
        """Store the end points of the linear noise schedule."""
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_noise(self, t, beta_init, beta_term, cumulative=False):
        """Return the linear noise level at ``t`` or its integral over [0, t].

        With ``cumulative`` true the result is
        ``beta_init*t + 0.5*(beta_term - beta_init)*t**2``; otherwise it is
        ``beta_init + (beta_term - beta_init)*t``.
        """
        if cumulative:
            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
        else:
            noise = beta_init + (beta_term - beta_init)*t
        return noise

    def get_gamma(self, s, t, beta_init, beta_term):
        """Return ``exp(-0.5 * (cumulative noise at t - cumulative noise at s))``."""
        gamma = beta_init*(t-s) + 0.5*(beta_term-beta_init)*(t**2-s**2)
        gamma = torch.exp(-0.5*gamma)
        return gamma

    def _gammas(self, s, t):
        """Return ``(gamma(0, s), gamma(0, t), gamma(s, t))`` for this schedule."""
        gamma_0_s = self.get_gamma(0, s, self.beta_min, self.beta_max)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        gamma_s_t = self.get_gamma(s, t, self.beta_min, self.beta_max)
        return gamma_0_s, gamma_0_t, gamma_s_t

    def get_mu(self, s, t):
        """Return ``gamma_s_t * (1 - gamma_0_s**2) / (1 - gamma_0_t**2)``."""
        gamma_0_s, gamma_0_t, gamma_s_t = self._gammas(s, t)
        mu = gamma_s_t * ((1-gamma_0_s**2) / (1-gamma_0_t**2))
        return mu

    def get_nu(self, s, t):
        """Return ``gamma_0_s * (1 - gamma_s_t**2) / (1 - gamma_0_t**2)``."""
        gamma_0_s, gamma_0_t, gamma_s_t = self._gammas(s, t)
        nu = gamma_0_s * ((1-gamma_s_t**2) / (1-gamma_0_t**2))
        return nu

    def get_sigma(self, s, t):
        """Return ``sqrt((1-gamma_0_s**2) * (1-gamma_s_t**2) / (1-gamma_0_t**2))``."""
        gamma_0_s, gamma_0_t, gamma_s_t = self._gammas(s, t)
        sigma = torch.sqrt(((1 - gamma_0_s**2) * (1 - gamma_s_t**2)) / (1 - gamma_0_t**2))
        return sigma

    def get_kappa(self, t, h, noise):
        """Return ``nu(t-h, t) * (1 - gamma_0_t**2) / (gamma_0_t * noise * h) - 1``.

        ``t`` and ``noise`` must broadcast against each other.
        """
        nu = self.get_nu(t-h, t)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        kappa = (nu*(1-gamma_0_t**2)/(gamma_0_t*noise*h) - 1)
        return kappa

    def get_omega(self, t, h, noise):
        """Return ``(mu(t-h, t) - 1)/(noise*h) + (1 + kappa)/(1 - gamma_0_t**2) - 0.5``.

        ``t`` and ``noise`` must broadcast against each other.
        """
        mu = self.get_mu(t-h, t)
        kappa = self.get_kappa(t, h, noise)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        omega = (mu-1)/(noise*h) + (1+kappa)/(1-gamma_0_t**2) - 0.5
        return omega

    @torch.no_grad()
    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc=False):
        """Run ``n_timesteps`` equal steps of size ``1/n_timesteps`` from ``t = 1``.

        Args:
            estimator: Callable invoked as ``estimator(spk, xt, mask, mu, t)``
                with ``t`` of shape ``(batch,)``; its output is combined
                element-wise with ``xt``.
            spk: Passed through to ``estimator`` untouched.
            z: Initial state; the first axis is treated as the batch axis.
            mask: Multiplied into the state before the loop and after every
                step.
            mu: Value subtracted from the state in the update term.
            n_timesteps: Number of steps; must be non-zero because the step
                size is ``1.0 / n_timesteps``.
            stoc: Not read by this sampler; the random term is always added.

        Returns:
            The masked state after the last step.
        """
        print("use MaxLikelihood reverse")
        h = 1.0 / n_timesteps
        xt = z * mask
        ones = torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
        # (batch, 1, ..., 1): lets per-sample scalars broadcast over ``xt``.
        coeff_shape = (z.shape[0],) + (1,) * (z.dim() - 1)
        for i in range(n_timesteps):
            t = (1.0 - i*h) * ones
            time = t.reshape(coeff_shape)
            noise_t = self.get_noise(time, self.beta_min, self.beta_max,
                                     cumulative=False)
            # Evaluate the coefficients on the broadcast-shaped ``time``.
            # Evaluating them on the flat ``t`` and combining with the
            # reshaped ``noise_t`` gives (batch, 1, batch) tensors, which
            # only line up with the state when batch == 1.
            kappa_t_h = self.get_kappa(time, h, noise_t)
            omega_t_h = self.get_omega(time, h, noise_t)
            sigma_t_h = self.get_sigma(time - h, time)
            es = estimator(spk, xt, mask, mu, t)
            dxt = ((0.5+omega_t_h)*(xt - mu) + (1+kappa_t_h) * es)
            dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                   requires_grad=False)
            dxt_stoc = dxt_stoc * sigma_t_h
            dxt = dxt * noise_t * h + dxt_stoc
            xt = (xt + dxt) * mask
        return xt
class GradRaw:
    """Fixed-step reverse sampler evaluated at the midpoint of each step.

    The noise schedule is linear between ``beta_min`` and ``beta_max``.
    """

    def __init__(self, beta_min=0.05, beta_max=20):
        """Remember the end points of the linear noise schedule."""
        self.beta_min, self.beta_max = beta_min, beta_max

    def get_noise(self, t, beta_init, beta_term, cumulative=False):
        """Return the noise level at ``t``, or its integral over [0, t].

        Args:
            t: Time value (number or tensor).
            beta_init: Noise level at ``t = 0``.
            beta_term: Noise level at ``t = 1``.
            cumulative: Select the integrated form
                ``beta_init*t + 0.5*(beta_term - beta_init)*t**2``.
        """
        if not cumulative:
            return beta_init + (beta_term - beta_init) * t
        return beta_init * t + 0.5 * (beta_term - beta_init) * (t ** 2)

    @torch.no_grad()
    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc=False):
        """Step the state from ``t = 1`` towards ``t = 0`` in equal increments.

        Args:
            estimator: Callable invoked once per step as
                ``estimator(spk, xt, mask, mu, t)`` with ``t`` of shape
                ``(batch,)``; its output is combined element-wise with ``xt``.
            spk: Passed through to ``estimator`` untouched.
            z: Initial state.  The noise level is reshaped with two trailing
                singleton axes, so ``z`` is presumably rank 3 -- confirm
                against the caller.
            mask: Multiplied into the state before the loop and after every
                step.
            mu: Target the update term is measured against.
            n_timesteps: Number of steps; the step size is
                ``1.0 / n_timesteps``.
            stoc: When true, a Gaussian term scaled by
                ``sqrt(noise * step)`` is added at every step.

        Returns:
            The masked state after the final step.
        """
        print("use grad-raw reverse")
        step = 1.0 / n_timesteps
        xt = z * mask
        batch_ones = torch.ones(z.shape[0], dtype=z.dtype, device=z.device)

        for idx in range(n_timesteps):
            # Midpoint of the idx-th interval, counted down from t = 1.
            t = (1.0 - (idx + 0.5) * step) * batch_ones
            beta_t = self.get_noise(
                t.unsqueeze(-1).unsqueeze(-1),
                self.beta_min,
                self.beta_max,
                cumulative=False,
            )
            est_out = estimator(spk, xt, mask, mu, t)

            if not stoc:
                dxt = 0.5 * (mu - xt - est_out)
                dxt = dxt * beta_t * step
            else:
                # adds stochastic term
                drift = (0.5 * (mu - xt) - est_out) * beta_t * step
                gauss = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                    requires_grad=False)
                dxt = drift + gauss * torch.sqrt(beta_t * step)

            xt = (xt - dxt) * mask
        return xt
|