File size: 3,993 Bytes
c8c2540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import numpy as np
from diffusers import DDPMScheduler
#changes_start
import transformers

def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) # type: torch.Tensor
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)

class FourierFeatureExtractor(torch.nn.Module):
    """Random Fourier feature embedding of scalar inputs (mirrors EDM2's ``MPFourier``).

    Each scalar ``t`` in the input is mapped to ``sqrt(2) * cos(freqs * t + phases)``,
    giving ``num_channels`` features per input element.

    Args:
        num_channels: number of output features per input element.
        bandwidth: multiplier on the random frequencies ``2*pi*N(0, 1)``.
    """

    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        # Drawn once and registered as buffers: not trained, but they follow the
        # module's .to()/.cuda() and are stored in its state_dict.
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of any shape ``(...)`` into features of shape ``(..., num_channels)``.

        The computation runs in float32; the result is cast back to ``x.dtype``.
        """
        y = x.to(torch.float32)
        # Outer product by broadcasting: (...,) x (C,) -> (..., C). Identical to the
        # former (deprecated, 1-D-only) `Tensor.ger` for 1-D input, and also accepts
        # a 0-d scalar or multi-dimensional input.
        y = y.unsqueeze(-1) * self.freqs.to(torch.float32)
        y = y + self.phases.to(torch.float32)
        # sqrt(2) scaling gives the cosine features unit mean-square magnitude.
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)

class NormalizedLinearLayer(torch.nn.Module):
    """Magnitude-preserving linear / 2-D convolution layer (mirrors EDM2's ``MPConv``).

    The weight has shape ``(out_channels, in_channels, *kernel)``: ``kernel=[]``
    yields a plain linear layer, a two-element ``kernel`` a 2-D convolution.

    Args:
        in_channels: number of input features / channels.
        out_channels: number of output features / channels.
        kernel: spatial kernel size as a sequence (``[]`` for linear).
    """

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        # Standard-normal init; forward() normalizes the weight before every use.
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain=1) -> torch.Tensor:
        """Apply the layer with per-output-unit normalized weights scaled by ``gain``.

        While ``self.training`` is set, the stored parameter is also overwritten in
        place with its normalized version before use.

        Returns ``x @ W.T`` for a 2-D weight, or ``conv2d(x, W)`` with padding
        ``kernel_width // 2`` for a 4-D weight.
        """
        raw = self.weight.to(torch.float32)
        if self.training:
            # Forced weight normalization: keep the stored parameter itself normalized.
            with torch.no_grad():
                self.weight.copy_(normalize(raw))
        fan_in = raw[0].numel()
        # Traditional weight normalization followed by magnitude-preserving 1/sqrt(fan_in) scaling.
        effective = (normalize(raw) * (gain / np.sqrt(fan_in))).to(x.dtype)
        if effective.ndim != 2:
            assert effective.ndim == 4
            return torch.nn.functional.conv2d(x, effective, padding=(effective.shape[-1] // 2,))
        return x @ effective.t()
    
class AdaptiveLossWeightMLP(nn.Module):
    """Learned per-timestep loss weighting (EDM2-style log-variance head).

    A timestep ``t`` is mapped to a scalar ``u(t)``. First ``alphas_cumprod[t]`` is
    standardized with the schedule's mean/std. It is then embedded with random
    Fourier features and projected to one channel by a magnitude-preserving
    linear layer. :meth:`forward` returns ``loss * lambda[t] / exp(u(t)) + u(t)``.
    Minimizing that trains ``u`` jointly with the model.

    Args:
        noise_scheduler: scheduler whose ``alphas_cumprod`` table is indexed by timestep.
        logvar_channels: number of Fourier features fed to the linear head.
        lambda_weights: optional fixed per-timestep weights, indexed by timestep.
            Defaults to all ones with one entry per ``alphas_cumprod`` entry.
    """

    def __init__(
            self,
            noise_scheduler: DDPMScheduler,
            logvar_channels=128,
            lambda_weights: torch.Tensor = None,
        ):
        super().__init__()
        # Keep the previous CUDA placement when available, but fall back to CPU so the
        # module can be constructed without a GPU (the old bare `.cuda()` crashed).
        default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Deliberately plain attributes rather than buffers: a module-wide `.to(dtype)`
        # would otherwise downcast the schedule table. They are moved to the compute
        # device on the fly in `_forward` / `forward` instead.
        self.alphas_cumprod = noise_scheduler.alphas_cumprod.to(default_device)
        # Statistics used to standardize alpha_bar before the Fourier embedding.
        self.a_bar_mean = self.alphas_cumprod.mean()
        self.a_bar_std = self.alphas_cumprod.std()
        self.logvar_fourier = FourierFeatureExtractor(logvar_channels)
        # kernel=[] gives a 2-D weight (1, logvar_channels), i.e. a plain linear map, as in EDM2.
        self.logvar_linear = NormalizedLinearLayer(logvar_channels, 1, kernel=[])
        if lambda_weights is None:
            # One weight per schedule timestep (previously hard-coded to 1000 entries).
            lambda_weights = torch.ones(len(self.alphas_cumprod))
        self.lambda_weights = torch.as_tensor(lambda_weights).to(default_device)
        self.noise_scheduler = noise_scheduler

    def _forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return the learned log-variance ``u(t)`` for each entry of ``timesteps``."""
        # Compute on whatever device the learnable head currently lives on.
        device = self.logvar_linear.weight.device
        timesteps = torch.as_tensor(timesteps, device=device)
        a_bar = self.alphas_cumprod.to(device)[timesteps]
        c_noise = (a_bar - self.a_bar_mean.to(device)) / self.a_bar_std.to(device)
        # squeeze(-1) drops only the single output channel, so a batch of one keeps
        # its batch dimension instead of collapsing to a 0-d tensor.
        return self.logvar_linear(self.logvar_fourier(c_noise)).squeeze(-1)

    def forward(self, loss: torch.Tensor, timesteps):
        """Apply the adaptive weighting to ``loss``.

        Args:
            loss: loss values to reweight. NOTE(review): presumably one value per
                sample, shape (B,), aligned with ``timesteps``. With a higher-rank
                loss, the (B,) weights broadcast against the *last* dim. Confirm
                against the caller.
            timesteps: integer timesteps used to index the schedule tables.

        Returns:
            ``(loss, loss_scaled)``. Here ``loss_scaled = loss * lambda[t] / exp(u(t))``,
            and the first element is ``loss_scaled + u(t)``, the quantity to minimize.
        """
        adaptive_loss_weights = self._forward(timesteps)
        device = adaptive_loss_weights.device
        lambda_t = self.lambda_weights.to(device)[torch.as_tensor(timesteps, device=device)]
        loss_scaled = loss * (lambda_t / torch.exp(adaptive_loss_weights)) # type: torch.Tensor
        loss = loss_scaled + adaptive_loss_weights # type: torch.Tensor

        return loss, loss_scaled
    
def create_weight_MLP(noise_scheduler, logvar_channels=128, lambda_weights=None):
    """Build an AdaptiveLossWeightMLP together with the optimizer that trains it.

    Args:
        noise_scheduler: forwarded to ``AdaptiveLossWeightMLP``.
        logvar_channels: forwarded to ``AdaptiveLossWeightMLP``.
        lambda_weights: forwarded to ``AdaptiveLossWeightMLP``.

    Returns:
        ``(module, optimizer)``. The optimizer is a transformers ``Adafactor`` over
        the module's parameters, with a fixed lr of 1e-2 and parameter-scale,
        relative-step and warmup-init modes all disabled.
    """
    print("creating weight MLP")
    weight_mlp = AdaptiveLossWeightMLP(noise_scheduler, logvar_channels, lambda_weights)
    # NOTE: torch.optim.AdamW(params, lr=1e-2, weight_decay=0) was used here previously.
    optimizer = transformers.optimization.Adafactor(
        weight_mlp.parameters(),
        lr=1e-2,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
    )
    return weight_mlp, optimizer