|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import scipy.signal |
|
import torch |
|
from torch_utils import persistence |
|
from torch_utils import misc |
|
from torch_utils.ops import upfirdn2d |
|
from torch_utils.ops import grid_sample_gradfix |
|
from torch_utils.ops import conv2d_gradfix |
|
|
|
|
|
|
|
|
|
|
|
def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-step noise variances (betas) for a diffusion process.

    Args:
        beta_schedule: Name of the schedule. One of ``"continuous_t"``,
            ``"quad"``, ``"linear"``, ``"const"``, ``"jsd"`` or ``"sigmoid"``.
        beta_start: First beta value (used by ``quad``, ``linear``, ``sigmoid``).
        beta_end: Last beta value (used by ``quad``, ``linear``, ``sigmoid``;
            ``const`` repeats it for every step).
        num_diffusion_timesteps: Number of betas to produce.

    Returns:
        A 1-D NumPy array of length ``num_diffusion_timesteps``.

    Raises:
        NotImplementedError: If ``beta_schedule`` is not a known name.
    """
    n_steps = num_diffusion_timesteps

    def _logistic(z):
        # Standard logistic function.
        return 1 / (np.exp(-z) + 1)

    def _continuous_beta(step, horizon):
        # Betas are 1 - alpha, with alpha an exponential of a term linear in
        # the step index; the two rate constants are fixed by this schedule.
        rate_hi = 5.
        rate_lo = 0.1
        decay = np.exp(
            -rate_lo / horizon - 0.5 * (rate_hi - rate_lo) * (2 * step - 1) / horizon ** 2
        )
        return 1 - decay

    def _build():
        if beta_schedule == "continuous_t":
            return _continuous_beta(np.arange(1, n_steps + 1), n_steps)
        if beta_schedule == "quad":
            # Linear in sqrt(beta), then squared.
            roots = np.linspace(
                beta_start ** 0.5, beta_end ** 0.5, n_steps, dtype=np.float64
            )
            return roots ** 2
        if beta_schedule == "linear":
            return np.linspace(beta_start, beta_end, n_steps, dtype=np.float64)
        if beta_schedule == "const":
            return beta_end * np.ones(n_steps, dtype=np.float64)
        if beta_schedule == "jsd":
            # 1/T, 1/(T-1), ..., 1
            return 1.0 / np.linspace(n_steps, 1, n_steps, dtype=np.float64)
        if beta_schedule == "sigmoid":
            grid = np.linspace(-6, 6, n_steps)
            return _logistic(grid) * (beta_end - beta_start) + beta_start
        raise NotImplementedError(beta_schedule)

    betas = _build()
    assert betas.shape == (n_steps,)
    return betas
|
|
|
|
|
def q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise_type='gauss', noise_std=1.0):
    """Draw a noised sample ``x_t`` from a clean 4-D input ``x_0``.

    Computes ``alphas_bar_sqrt[t] * x_0 + one_minus_alphas_bar_sqrt[t] * noise``
    where the two coefficient tables are indexed by ``t`` and reshaped to
    ``(batch, channels, 1, 1)``, so ``t`` must hold one index per
    (sample, channel) pair.

    Args:
        x_0: 4-D tensor; its first two dimensions are used as batch and channel.
        alphas_bar_sqrt: 1-D lookup table of signal coefficients.
        one_minus_alphas_bar_sqrt: 1-D lookup table of noise coefficients.
        t: Index tensor with ``batch * channels`` entries.
        noise_type: ``'gauss'`` for normal noise, ``'bernoulli'`` for +/-1 noise.
        noise_std: Multiplier applied to the raw noise.

    Returns:
        Tensor with the same shape as ``x_0``.

    Raises:
        NotImplementedError: If ``noise_type`` is not recognised.
    """
    n_batch, n_chan, _, _ = x_0.shape

    # Noise is drawn before the coefficient lookup so an unknown noise type
    # fails before any indexing happens.
    if noise_type == 'bernoulli':
        coin = torch.bernoulli(torch.ones_like(x_0) * 0.5)
        eps = (coin * 2 - 1.) * noise_std
    elif noise_type == 'gauss':
        eps = torch.randn_like(x_0) * noise_std
    else:
        raise NotImplementedError(noise_type)

    def _per_channel(table):
        # One coefficient per (sample, channel), broadcast over H and W.
        return table[t].view(n_batch, n_chan, 1, 1)

    signal_scale = _per_channel(alphas_bar_sqrt)
    noise_scale = _per_channel(one_minus_alphas_bar_sqrt)
    return signal_scale * x_0 + noise_scale * eps
|
|
|
|
|
@persistence.persistent_class
class Diffusion(torch.nn.Module):
    """Forward-diffusion module that adds noise to 4-D image batches.

    The number of diffusion steps ``T`` is derived from ``self.p`` (which
    starts at 0.0) via :meth:`update_T`: ``T = t_min + round(p * (t_max - t_min))``,
    clipped to ``[t_min, t_max]``. ``update_T`` also refreshes ``self.t_epl``,
    a pool of 64 timestep indices from which :meth:`forward` draws.

    Args:
        beta_schedule: Stored as ``self.base_schedule``.
            NOTE(review): ``update_T`` always builds a ``"linear"`` schedule,
            so this value is currently not used to build betas -- confirm
            whether that is intended.
        beta_start: First beta passed to ``get_beta_schedule``.
        beta_end: Last beta passed to ``get_beta_schedule``.
        t_min: Smallest number of diffusion steps.
        t_max: Largest number of diffusion steps.
        noise_std: Stored as ``self.noise_std``.
            NOTE(review): ``forward`` takes its own ``noise_std`` argument and
            does not read this attribute.
    """

    def __init__(self,
        beta_schedule='linear', beta_start=1e-4, beta_end=1e-2,
        t_min=5, t_max=500, noise_std=0.5,
    ):
        super().__init__()
        self.p = 0.0  # strength in [0, 1] that scales T; presumably updated by the caller
        self.noise_type = self.base_noise_type = 'gauss'
        self.base_schedule = beta_schedule
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.t_min = t_min
        self.t_max = t_max
        self.t_add = t_max - t_min  # range over which p scales T
        self.update_T()

        self.noise_std = float(noise_std)

    def set_diffusion_process(self, t, beta_schedule):
        """Recompute the beta/alpha tables for a process with ``t`` steps.

        Sets ``betas``, ``alphas``, ``num_timesteps``, ``alphas_bar_sqrt`` and
        ``one_minus_alphas_bar_sqrt``. The cumulative product is prefixed with
        1.0, so index 0 of the two ``*_sqrt`` tables corresponds to "no noise"
        (coefficients 1 and 0) and both tables have ``t + 1`` entries.
        """
        betas = get_beta_schedule(
            beta_schedule=beta_schedule,
            beta_start=self.beta_start,
            beta_end=self.beta_end,
            num_diffusion_timesteps=t,
        )

        betas = self.betas = torch.from_numpy(betas).float()
        self.num_timesteps = betas.shape[0]

        alphas = self.alphas = 1.0 - betas
        alphas_cumprod = torch.cat([torch.tensor([1.]), alphas.cumprod(dim=0)])
        self.alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)

    def update_T(self):
        """Re-derive T from ``self.p`` and resample the timestep pool.

        The pool ``self.t_epl`` has 64 integer slots. The first
        ``min(round(p * 64), 48)`` slots are filled with timesteps sampled
        from ``1..T``; the remaining slots stay 0 (the no-noise index).
        Timestep ``k`` is drawn with weight proportional to ``k - 1``, so
        larger timesteps are more likely and timestep 1 is never drawn.
        """
        t_adjust = round(self.p * self.t_add)
        t = np.clip(int(self.t_min + t_adjust), a_min=self.t_min, a_max=self.t_max)
        # NOTE(review): schedule is hard-coded; self.base_schedule is not used here.
        self.set_diffusion_process(t, "linear")

        # `np.int` was deprecated in NumPy 1.20 and removed in 1.24; use an
        # explicit 64-bit integer so the pool converts to a torch long tensor.
        self.t_epl = np.zeros(64, dtype=np.int64)
        diffusion_ind = min(round(self.p * 64), 48)
        # NOTE: the weights sum to zero when t == 1, so this requires t >= 2.
        prob_t = np.arange(t) / np.arange(t).sum()
        t_diffusion = np.random.choice(np.arange(1, t+1), size=diffusion_ind, p=prob_t)
        self.t_epl[:diffusion_ind] = t_diffusion

    def forward(self, x_0, noise_std=1.0):
        """Return a noised copy of ``x_0``.

        One timestep is drawn (with replacement) from ``self.t_epl`` for every
        (sample, channel) pair, then ``q_sample`` mixes ``x_0`` with noise of
        type ``self.noise_type`` using the coefficients for that timestep.

        Args:
            x_0: 4-D tensor ``(batch, channels, height, width)``.
            noise_std: Multiplier for the noise, forwarded to ``q_sample``.

        Returns:
            Tensor with the same shape as ``x_0``.
        """
        assert isinstance(x_0, torch.Tensor) and x_0.ndim == 4
        batch_size, num_channels, height, width = x_0.shape
        device = x_0.device

        alphas_bar_sqrt = self.alphas_bar_sqrt.to(device)
        one_minus_alphas_bar_sqrt = self.one_minus_alphas_bar_sqrt.to(device)

        t = torch.from_numpy(np.random.choice(self.t_epl, size=batch_size * num_channels, replace=True)).to(device)

        x_t = q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t,
                       noise_type=self.noise_type,
                       noise_std=noise_std)
        return x_t
|
|
|
|