import math
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import transformers
from diffusers import DDPMScheduler
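
# The helpers below follow a magnitude-preserving design: `normalize` rescales
# tensors to unit norm, FourierFeatureExtractor embeds a scalar condition as
# random Fourier features, and NormalizedLinearLayer keeps its weights forcibly
# normalized during training. AdaptiveLossWeightMLP combines them to learn a
# per-timestep log-variance that re-weights the diffusion training loss.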


def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    """Normalize `x` to unit magnitude along `dim` (all non-batch dims by default)."""
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    # Add a scaled eps so near-zero inputs stay well-behaved.
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)


class FourierFeatureExtractor(torch.nn.Module):
    """Embeds a scalar conditioning value as a vector of random Fourier features."""

    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))  # outer product: (batch,) x (channels,) -> (batch, channels)
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)  # sqrt(2) keeps the cosine features at unit variance
        return y.to(x.dtype)


class NormalizedLinearLayer(torch.nn.Module):
    """Linear/conv layer whose weights are forcibly re-normalized to unit norm during training."""

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain=1) -> torch.Tensor:
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w))  # forced weight normalization
        w = normalize(w)  # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel()))  # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=w.shape[-1] // 2)


class AdaptiveLossWeightMLP(nn.Module):
    """Learns a per-timestep log-variance u(t) and uses it to re-weight the diffusion loss."""

    def __init__(
        self,
        noise_scheduler: DDPMScheduler,
        logvar_channels=128,
        lambda_weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.alphas_cumprod = noise_scheduler.alphas_cumprod.cuda()
        # Standardize alpha_bar so the Fourier features see a roughly zero-mean, unit-variance input.
        self.a_bar_mean = self.alphas_cumprod.mean()
        self.a_bar_std = self.alphas_cumprod.std()
        self.logvar_fourier = FourierFeatureExtractor(logvar_channels)
        self.logvar_linear = NormalizedLinearLayer(logvar_channels, 1, kernel=[])
        self.lambda_weights = lambda_weights.cuda() if lambda_weights is not None else torch.ones(1000, device='cuda')
        self.noise_scheduler = noise_scheduler

    def _forward(self, timesteps: torch.Tensor):
        # Predict the log-variance u(t) from the standardized cumulative alpha at each timestep.
        a_bar = self.alphas_cumprod[timesteps]
        c_noise = a_bar.sub(self.a_bar_mean).div_(self.a_bar_std)
        return self.logvar_linear(self.logvar_fourier(c_noise)).squeeze()

    def forward(self, loss: torch.Tensor, timesteps):
        # Uncertainty-style weighting: minimize lambda(t) * loss / exp(u(t)) + u(t).
        adaptive_loss_weights = self._forward(timesteps)
        loss_scaled = loss * (self.lambda_weights[timesteps] / torch.exp(adaptive_loss_weights))
        loss = loss_scaled + adaptive_loss_weights
        return loss, loss_scaled


def create_weight_MLP(noise_scheduler, logvar_channels=128, lambda_weights=None):
    print("creating weight MLP")
    lossweightMLP = AdaptiveLossWeightMLP(noise_scheduler, logvar_channels, lambda_weights)
    MLP_optim = transformers.optimization.Adafactor(
        lossweightMLP.parameters(),
        lr=1.5e-2,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
    )

    disable_scheduler = False
    if disable_scheduler:
        def lr_lambda(current_step: int):
            return 1
    else:
        def lr_lambda(current_step: int):
            # Linear warmup, then a constant phase, then inverse-square-root decay.
            warmup_steps = 100
            constant_steps = 300
            if current_step <= warmup_steps:
                return current_step / max(1, warmup_steps)
            else:
                return 1 / math.sqrt(max(current_step / (warmup_steps + constant_steps), 1))

    MLP_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer=MLP_optim,
        lr_lambda=lr_lambda
    )

    return lossweightMLP, MLP_optim, MLP_scheduler
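

# Illustrative usage sketch (an addition, not part of the original training code):
# wires the weight MLP, its optimizer, and its scheduler into one dummy training
# step. Requires a CUDA device, since the module hard-codes `.cuda()` above; the
# batch size and the per-sample loss here are placeholder assumptions.
if __name__ == "__main__":
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    mlp, mlp_optim, mlp_sched = create_weight_MLP(noise_scheduler)
    mlp.cuda()  # move the MLP's parameters onto the GPU, next to the buffers created in __init__

    # One dummy optimization step, standing in for a real diffusion training loop.
    per_sample_loss = torch.rand(8, device='cuda')  # placeholder per-sample diffusion loss
    timesteps = torch.randint(0, 1000, (8,), device='cuda')
    weighted_loss, scaled_loss = mlp(per_sample_loss, timesteps)
    weighted_loss.mean().backward()
    mlp_optim.step()
    mlp_sched.step()
    mlp_optim.zero_grad()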