import abc

import torch
import torch.nn as nn

# Run the legacy (non-profiling) TorchScript executor and allow
# pointwise op fusion on both CPU and GPU.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def get_noise(config, dtype=torch.float32):
  if config.noise.type == 'geometric':
    return GeometricNoise(config.noise.sigma_min,
                          config.noise.sigma_max)
  elif config.noise.type == 'loglinear':
    return LogLinearNoise()
  elif config.noise.type == 'cosine':
    return CosineNoise()
  elif config.noise.type == 'cosinesqr':
    return CosineSqrNoise()
  elif config.noise.type == 'linear':
    return Linear(config.noise.sigma_min,
                  config.noise.sigma_max,
                  dtype)
  else:
    raise ValueError(f'{config.noise.type} is not a valid noise')
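
# Illustrative usage (a sketch, not part of the original module): any
# object exposing `config.noise.type` (plus `noise.sigma_min` /
# `noise.sigma_max` for the geometric and linear schedules) works here,
# e.g. a hydra/omegaconf config; `SimpleNamespace` is just a stand-in.
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(noise=SimpleNamespace(type='loglinear'))
#   noise = get_noise(config)
#   total, rate = noise(torch.linspace(0, 1, 8))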


def binary_discretization(z):
  """Straight-through estimator: sign(z) forward, soft gradient back."""
  z_hard = torch.sign(z)
  z_soft = z / torch.norm(z, dim=-1, keepdim=True)
  return z_soft + (z_hard - z_soft).detach()
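
# Example of the straight-through behavior (shapes are hypothetical):
# the forward value is sign(z), while gradients flow through the soft
# (normalized) branch because the hard correction is detached.
#
#   z = torch.randn(2, 4, requires_grad=True)
#   out = binary_discretization(z)   # entries in {-1, 0, +1}
#   out.sum().backward()             # grad comes from z / ||z||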


class Noise(abc.ABC, nn.Module):
  """
  Baseline forward method to get the total + rate of noise at a timestep.
  """
  def forward(self, t):
    return self.total_noise(t), self.rate_noise(t)

  @abc.abstractmethod
  def rate_noise(self, t):
    """
    Rate of change of noise, i.e. g(t).
    """
    pass

  @abc.abstractmethod
  def total_noise(self, t):
    r"""
    Total noise, i.e. \int_0^t g(s) ds + g(0).
    """
    pass


class CosineNoise(Noise):
  def __init__(self, eps=1e-3):
    super().__init__()
    self.eps = eps

  def rate_noise(self, t):
    cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
    sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
    scale = torch.pi / 2
    return scale * sin / (cos + self.eps)

  def total_noise(self, t):
    cos = torch.cos(t * torch.pi / 2)
    return - torch.log(self.eps + (1 - self.eps) * cos)
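
# Consistency check: differentiating total_noise(t) gives
#   d/dt -log(eps + (1 - eps) * cos(pi * t / 2))
#     = (pi / 2) * (1 - eps) * sin(pi * t / 2)
#       / (eps + (1 - eps) * cos(pi * t / 2)),
# which is exactly rate_noise(t) written with the `sin`, `cos`, and
# `scale` variables above.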


class CosineSqrNoise(Noise):
  def __init__(self, eps=1e-3):
    super().__init__()
    self.eps = eps

  def rate_noise(self, t):
    cos = (1 - self.eps) * (
      torch.cos(t * torch.pi / 2) ** 2)
    sin = (1 - self.eps) * torch.sin(t * torch.pi)
    scale = torch.pi / 2
    return scale * sin / (cos + self.eps)

  def total_noise(self, t):
    cos = torch.cos(t * torch.pi / 2) ** 2
    return - torch.log(self.eps + (1 - self.eps) * cos)
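
# Same check as for CosineNoise: applying the chain rule to
# total_noise(t) and using the identity
# sin(pi * t) = 2 * sin(pi * t / 2) * cos(pi * t / 2)
# recovers rate_noise(t) above.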


class Linear(Noise):
  def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
    super().__init__()
    self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
    self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

  def rate_noise(self, t):
    return self.sigma_max - self.sigma_min

  def total_noise(self, t):
    return self.sigma_min + t * (self.sigma_max - self.sigma_min)

  def importance_sampling_transformation(self, t):
    f_T = torch.log1p(- torch.exp(- self.sigma_max))
    f_0 = torch.log1p(- torch.exp(- self.sigma_min))
    sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
    return (sigma_t - self.sigma_min) / (
      self.sigma_max - self.sigma_min)
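
# Endpoint check for importance_sampling_transformation: at t = 0 the
# interpolation returns f_0, so sigma_t = sigma_min and the output is 0;
# at t = 1 it returns f_T, so sigma_t = sigma_max and the output is 1.
# In between, uniform t is warped so that log(1 - exp(-sigma_t)) moves
# linearly rather than sigma_t itself.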


class GeometricNoise(Noise):
  def __init__(self, sigma_min=1e-3, sigma_max=1):
    super().__init__()
    self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

  def rate_noise(self, t):
    return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
      self.sigmas[1].log() - self.sigmas[0].log())

  def total_noise(self, t):
    return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
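
# Consistency check: total_noise(t) = sigma_min**(1 - t) * sigma_max**t,
# and its derivative in t is total_noise(t) * log(sigma_max / sigma_min),
# which is exactly rate_noise(t). Note total_noise(0) = sigma_min, so
# this schedule starts at a small but nonzero noise level.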


class LogLinearNoise(Noise):
  """Log Linear noise schedule.

  Built such that 1 - 1/e^(n(t)) interpolates between 0 and
  ~1 when t varies from 0 to 1. Total noise is
  -log(1 - (1 - eps) * t), so 1 - exp(-total_noise(t))
  equals (1 - eps) * t.
  """
  def __init__(self, eps=1e-3):
    super().__init__()
    self.eps = eps
    self.sigma_max = self.total_noise(torch.tensor(1.0))
    self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

  def rate_noise(self, t):
    return (1 - self.eps) / (1 - (1 - self.eps) * t)

  def total_noise(self, t):
    return -torch.log1p(-(1 - self.eps) * t)

  def importance_sampling_transformation(self, t):
    f_T = torch.log1p(- torch.exp(- self.sigma_max))
    f_0 = torch.log1p(- torch.exp(- self.sigma_min))
    sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
    t = - torch.expm1(- sigma_t) / (1 - self.eps)
    return t
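
# Worked endpoints for LogLinearNoise: total_noise(0) = -log1p(0) = 0 and
# total_noise(1) = -log(eps), about 6.9 for eps = 1e-3. The importance
# sampling transform mirrors Linear's: interpolate log(1 - exp(-sigma))
# linearly between its endpoint values, then invert total_noise via
# t = -expm1(-sigma_t) / (1 - eps) to recover a timestep.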