import torch
import torch.nn as nn
import numpy as np
from functools import partial
from inspect import isfunction


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
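

# Gather the per-timestep coefficient a[t] for each batch element and reshape it
# to (b, 1, ..., 1) so it broadcasts against a tensor of shape x_shape.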
def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
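

# "Scaled linear" beta schedule: the betas are the squares of an evenly spaced
# ramp from sqrt(linear_start) to sqrt(linear_end), computed in float64 so the
# cumulative products below lose less precision.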
def make_beta_schedule(n_timestep, linear_start=1e-4, linear_end=2e-2):
    betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return betas.numpy()
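

# Base class for low-scale conditioning models: it owns the diffusion noise
# schedule and exposes q_sample for noising an input; forward and decode
# default to the identity.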
class AbstractLowScaleModel(nn.Module):
    def __init__(self, noise_schedule_config=None):
        super().__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)
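
    # Precompute the standard DDPM schedule quantities (betas, cumulative alpha
    # products, and the derived square roots) from the beta schedule.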
    def register_schedule(self, timesteps=1000, linear_start=1e-4, linear_end=2e-2):
        betas = make_beta_schedule(timesteps, linear_start=linear_start, linear_end=linear_end)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
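
        # Register the schedule as non-trainable buffers so it moves with the
        # module across .to(device) calls; values are cast float64 -> float32.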
        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
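
    # Forward diffusion in closed form: sample from q(x_t | x_0) as
    # x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise.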
    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        return x, None

    def decode(self, x):
        return x
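

# Pass the conditioning image through unchanged; the reported noise level is a
# constant zero per batch element, i.e. no noise augmentation is applied.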
class SimpleImageConcat(AbstractLowScaleModel):
    def __init__(self):
        super().__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        return x, torch.zeros(x.shape[0], device=x.device).long()
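

# Noise-augmentation conditioner: noise the low-res image via q_sample at a
# sampled (or caller-supplied) level, and return both so the noise level can
# be fed to the downstream model as conditioning.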
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level
        # NOTE: to_cuda is currently unused in this module.

    def forward(self, x, noise_level=None):
        if noise_level is None:
            # Sample one integer noise level per batch element, uniform on [0, max_noise_level).
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level)
        return z, noise_level
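

# Minimal usage sketch, assuming a hypothetical low-res conditioning batch
# `x_low`; the config values and shapes below are illustrative only.
if __name__ == "__main__":
    cfg = {"timesteps": 1000, "linear_start": 1e-4, "linear_end": 2e-2}
    aug = ImageConcatWithNoiseAugmentation(noise_schedule_config=cfg, max_noise_level=1000)

    x_low = torch.randn(2, 3, 64, 64)  # hypothetical low-res conditioning images
    z, noise_level = aug(x_low)        # noised images plus the sampled per-example levels
    assert z.shape == x_low.shape and noise_level.shape == (2,)

    # The noise level can also be fixed rather than sampled:
    fixed = torch.full((2,), 250, dtype=torch.long)
    z_fixed, _ = aug(x_low, noise_level=fixed)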