import torch
import torch.nn as nn
import numpy as np
from diffusers import DDPMScheduler
#changes_start
import transformers
import math
#from adopt import ADOPT
def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)  # type: torch.Tensor
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)
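
# Added note (not in the original file): normalize() rescales each sample of x to
# approximately unit RMS magnitude. torch.add(eps, norm, alpha=sqrt(norm.numel()/x.numel()))
# computes eps + ||x|| / sqrt(fan), i.e. the per-sample RMS plus a small epsilon, so the
# division stays stable for near-zero inputs (as in EDM2's forced weight normalization).
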
class FourierFeatureExtractor(torch.nn.Module):
    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)  # type: torch.Tensor
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)
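
# Added note (not in the original file): with phases drawn uniformly in [0, 2*pi),
# sqrt(2) * cos(freq * x + phase) has unit variance, so this Fourier embedding of the
# noise level is magnitude-preserving, matching EDM2's MPFourier module.
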
class NormalizedLinearLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain=1) -> torch.Tensor:
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w))  # forced weight normalization
        w = normalize(w)  # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel()))  # type: torch.Tensor # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
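
# Added note (not in the original file): NormalizedLinearLayer mirrors EDM2's MPConv.
# During training the raw weights are forcibly re-normalized to unit RMS, and the
# gain / sqrt(fan_in) factor scales the normalized weights so that a roughly
# unit-variance input yields a roughly unit-variance output. Illustrative check under
# that assumption:
#
#   layer = NormalizedLinearLayer(128, 1, kernel=[])  # kernel=[] -> 2D weight, i.e. a plain linear layer
#   y = layer(torch.randn(4096, 128))
#   y.std()  # expected to be close to 1.0
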
class AdaptiveLossWeightMLP(nn.Module):
    def __init__(
        self,
        noise_scheduler: DDPMScheduler,
        logvar_channels=128,
        lambda_weights: torch.Tensor = None,
    ):
        super().__init__()
        self.alphas_cumprod = noise_scheduler.alphas_cumprod.cuda()  # assumes a CUDA device is available
        #self.a_bar_mean = noise_scheduler.alphas_cumprod.mean()
        #self.a_bar_std = noise_scheduler.alphas_cumprod.std()
        self.a_bar_mean = self.alphas_cumprod.mean()
        self.a_bar_std = self.alphas_cumprod.std()
        self.logvar_fourier = FourierFeatureExtractor(logvar_channels)
        self.logvar_linear = NormalizedLinearLayer(logvar_channels, 1, kernel=[])  # kernel=[] -> plain linear layer (not in the code given; added to match EDM2)
        self.lambda_weights = lambda_weights.cuda() if lambda_weights is not None else torch.ones(1000, device='cuda')
        self.noise_scheduler = noise_scheduler

    def _forward(self, timesteps: torch.Tensor):
        #a_bar = self.noise_scheduler.alphas_cumprod[timesteps]
        a_bar = self.alphas_cumprod[timesteps]
        c_noise = a_bar.sub(self.a_bar_mean).div_(self.a_bar_std)
        return self.logvar_linear(self.logvar_fourier(c_noise)).squeeze()

    def forward(self, loss: torch.Tensor, timesteps):
        adaptive_loss_weights = self._forward(timesteps)
        loss_scaled = loss * (self.lambda_weights[timesteps] / torch.exp(adaptive_loss_weights))  # type: torch.Tensor
        loss = loss_scaled + adaptive_loss_weights  # type: torch.Tensor
        return loss, loss_scaled
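
# Added note (not part of the original code): this implements EDM2-style uncertainty
# weighting. With u(t) the scalar output of the logvar MLP, the loss returned for
# backprop is
#     L_weighted = lambda(t) * L / exp(u(t)) + u(t),
# so timesteps the network is uncertain about are down-weighted, while the additive
# u(t) term penalizes simply predicting large uncertainty everywhere. `loss_scaled`
# (the re-weighted loss without the u(t) offset) is returned alongside it, e.g. for
# logging.
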
def create_weight_MLP(noise_scheduler, logvar_channels=128, lambda_weights=None):
    print("creating weight MLP")
    lossweightMLP = AdaptiveLossWeightMLP(noise_scheduler, logvar_channels, lambda_weights)
    MLP_optim = transformers.optimization.Adafactor(
        lossweightMLP.parameters(),
        lr=1.5e-2,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False
    )
    # MLP_optim = torch.optim.AdamW(lossweightMLP.parameters(), lr=1.5e-2, weight_decay=0)
    # MLP_optim = lion_pytorch.Lion(lossweightMLP.parameters(), lr=2e-2)
    # MLP_optim = ADOPT(lossweightMLP.parameters(), lr=1.5e-2)
    disable_scheduler = False  # disable the LR scheduler
    if disable_scheduler:
        def lr_lambda(current_step: int):
            return 1  # constant
    else:
        def lr_lambda(current_step: int):
            warmup_steps = 100
            constant_steps = 300
            if current_step <= warmup_steps:
                return current_step / max(1, warmup_steps)
            else:
                return 1 / math.sqrt(max(current_step / (warmup_steps + constant_steps), 1))

    MLP_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer=MLP_optim,
        lr_lambda=lr_lambda
    )
    return lossweightMLP, MLP_optim, MLP_scheduler
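

if __name__ == "__main__":
    # Minimal usage sketch (added; not part of the original training integration).
    # It shows how the MLP, optimizer and scheduler returned by create_weight_MLP
    # could drive a single weighting + optimization step with a dummy per-sample
    # loss. A CUDA device is assumed, since AdaptiveLossWeightMLP hard-codes .cuda().
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    mlp, mlp_optim, mlp_sched = create_weight_MLP(noise_scheduler)
    mlp.cuda().train()

    per_sample_loss = torch.rand(8, device='cuda')            # stand-in for a per-sample diffusion loss
    timesteps = torch.randint(0, 1000, (8,), device='cuda')   # sampled training timesteps
    weighted_loss, scaled_loss = mlp(per_sample_loss, timesteps)

    weighted_loss.mean().backward()  # gradients only reach the logvar MLP parameters here
    mlp_optim.step()
    mlp_sched.step()
    mlp_optim.zero_grad()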