import torch
import torch.nn as nn
import numpy as np
from diffusers import DDPMScheduler
#changes_start
import transformers
def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    # EDM2-style normalization: rescale each slice along `dim` to unit
    # per-element RMS (norm ~= sqrt(fan_in)), with eps for stability.
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)
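
# --- hedged sanity check (illustrative, not part of the original upload) ---
# Assumed semantics: for a 2-D weight, each row is rescaled to norm
# ~= sqrt(fan_in), i.e. unit RMS per element, as in EDM2's normalize().
def _demo_normalize():
    w = torch.randn(8, 64) * 5.0
    row_norms = torch.linalg.vector_norm(normalize(w), dim=1)
    print(row_norms / np.sqrt(64))  # every entry should be ~1.0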
class FourierFeatureExtractor(torch.nn.Module):
    """Magnitude-preserving Fourier features (MPFourier in EDM2)."""

    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))  # outer product; x must be 1-D
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)  # sqrt(2) preserves unit variance
        return y.to(x.dtype)
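
# --- hedged usage sketch (illustrative) ---
# The extractor expects a 1-D input (one conditioning scalar per sample,
# here the standardized alpha_bar) and returns (batch, num_channels)
# features with roughly zero mean and unit variance.
def _demo_fourier():
    feats = FourierFeatureExtractor(num_channels=128)
    c_noise = torch.randn(16)
    y = feats(c_noise)
    print(y.shape, y.mean().item(), y.std().item())  # (16, 128), ~0, ~1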
class NormalizedLinearLayer(torch.nn.Module):
    """Magnitude-preserving linear / conv2d layer (MPConv in EDM2), bias-free."""

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain=1) -> torch.Tensor:
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w))  # forced weight normalization
        w = normalize(w)  # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel()))  # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=w.shape[-1] // 2)  # 'same' padding for odd kernels
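
# --- hedged usage sketch (illustrative) ---
# With kernel=[] the weight is 2-D and the layer acts as a bias-free
# linear map; with e.g. kernel=[3, 3] it becomes a same-padded conv2d.
def _demo_normalized_linear():
    layer = NormalizedLinearLayer(128, 1, kernel=[]).eval()
    x = torch.randn(16, 128)
    print(layer(x).shape)  # torch.Size([16, 1])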
class AdaptiveLossWeightMLP(nn.Module):
    """EDM2-style adaptive loss weighting: learns a per-timestep log-variance
    u(t) used to rebalance the diffusion loss (EDM2, eq. 21)."""

    def __init__(
        self,
        noise_scheduler: DDPMScheduler,
        logvar_channels=128,
        lambda_weights: torch.Tensor = None,
    ):
        super().__init__()
        self.alphas_cumprod = noise_scheduler.alphas_cumprod.cuda()
        self.a_bar_mean = self.alphas_cumprod.mean()
        self.a_bar_std = self.alphas_cumprod.std()
        self.logvar_fourier = FourierFeatureExtractor(logvar_channels)
        self.logvar_linear = NormalizedLinearLayer(logvar_channels, 1, kernel=[])  # kernel=[] -> 2-D weight, i.e. a plain linear layer, matching EDM2
        self.lambda_weights = lambda_weights.cuda() if lambda_weights is not None else torch.ones(1000, device='cuda')
        self.noise_scheduler = noise_scheduler
    def _forward(self, timesteps: torch.Tensor):
        # Standardize alpha_bar(t) before feeding it to the Fourier features.
        a_bar = self.alphas_cumprod[timesteps]
        c_noise = a_bar.sub(self.a_bar_mean).div_(self.a_bar_std)
        return self.logvar_linear(self.logvar_fourier(c_noise)).squeeze()

    def forward(self, loss: torch.Tensor, timesteps):
        # EDM2 uncertainty weighting: L' = lambda(t) / exp(u(t)) * L + u(t).
        adaptive_loss_weights = self._forward(timesteps)
        loss_scaled = loss * (self.lambda_weights[timesteps] / torch.exp(adaptive_loss_weights))
        loss = loss_scaled + adaptive_loss_weights
        return loss, loss_scaled
def create_weight_MLP(noise_scheduler, logvar_channels=128, lambda_weights=None):
    print("creating weight MLP")
    lossweightMLP = AdaptiveLossWeightMLP(noise_scheduler, logvar_channels, lambda_weights)
    # MLP_optim = torch.optim.AdamW(lossweightMLP.parameters(), lr=1e-2, weight_decay=0)
    MLP_optim = transformers.optimization.Adafactor(
        lossweightMLP.parameters(),
        lr=1e-2,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
    )
    return lossweightMLP, MLP_optim
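
# --- hedged usage sketch (illustrative; the tensors below are stand-ins) ---
# Wiring into a training step: the first return value (scaled loss plus
# the log-variance regularizer) is what should be backpropagated; the
# second is the lambda-weighted loss, useful for logging. Note that the
# module itself must be moved to CUDA, since __init__ only places the
# buffers there. Assumes a CUDA device is available, mirroring the
# hard-coded .cuda() calls above.
if __name__ == "__main__":
    scheduler = DDPMScheduler()
    mlp, mlp_optim = create_weight_MLP(scheduler)
    mlp = mlp.cuda()
    per_sample_loss = torch.rand(4, device='cuda')  # stand-in for the denoiser's per-sample MSE
    timesteps = torch.randint(0, 1000, (4,), device='cuda')
    weighted_loss, logged_loss = mlp(per_sample_loss, timesteps)
    weighted_loss.mean().backward()
    mlp_optim.step()
    mlp_optim.zero_grad()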