import torch
import torch.nn as nn
import numpy as np
from diffusers import DDPMScheduler
#changes_start
import transformers
import math
from typing import Optional
#from adopt import ADOPT

def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) # type: torch.Tensor
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)
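
# Illustrative example: each sample is rescaled to (approximately) unit RMS, i.e.
# an L2 norm of sqrt(#elements per sample), mirroring EDM2's magnitude-preserving layers:
#   x = torch.randn(8, 128) * 5.0
#   torch.linalg.vector_norm(normalize(x), dim=1)  # ~ sqrt(128) ~ 11.3 for every row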

class FourierFeatureExtractor(torch.nn.Module):
    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32) # type: torch.Tensor
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)
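
# Illustrative shapes: a 1-D tensor of per-sample conditioning values (here, the
# standardized alpha_bar computed below) maps to one Fourier feature vector per sample:
#   ff = FourierFeatureExtractor(num_channels=128)
#   ff(torch.randn(16)).shape  # torch.Size([16, 128])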

class NormalizedLinearLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain=1) -> torch.Tensor:
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w)) # forced weight normalization
        w = normalize(w) # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel())) # type: torch.Tensor # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
    
class AdaptiveLossWeightMLP(nn.Module):
    def __init__(
        self,
        noise_scheduler: DDPMScheduler,
        logvar_channels=128,
        lambda_weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.alphas_cumprod = noise_scheduler.alphas_cumprod.cuda()
        #self.a_bar_mean = noise_scheduler.alphas_cumprod.mean()
        #self.a_bar_std = noise_scheduler.alphas_cumprod.std()
        self.a_bar_mean = self.alphas_cumprod.mean()
        self.a_bar_std = self.alphas_cumprod.std()
        self.logvar_fourier = FourierFeatureExtractor(logvar_channels)
        self.logvar_linear = NormalizedLinearLayer(logvar_channels, 1, kernel=[]) # kernel=[] yields a plain linear layer, matching EDM2's MPConv
        self.lambda_weights = lambda_weights.cuda() if lambda_weights is not None else torch.ones(len(self.alphas_cumprod), device='cuda') # default: uniform weight per training timestep
        self.noise_scheduler = noise_scheduler

    def _forward(self, timesteps: torch.Tensor):
        #a_bar = self.noise_scheduler.alphas_cumprod[timesteps]
        a_bar = self.alphas_cumprod[timesteps]
        c_noise = a_bar.sub(self.a_bar_mean).div_(self.a_bar_std)
        return self.logvar_linear(self.logvar_fourier(c_noise)).squeeze()

    def forward(self, loss: torch.Tensor, timesteps):
        adaptive_loss_weights = self._forward(timesteps)
        loss_scaled = loss * (self.lambda_weights[timesteps] / torch.exp(adaptive_loss_weights)) # type: torch.Tensor
        loss = loss_scaled + adaptive_loss_weights # type: torch.Tensor

        return loss, loss_scaled
    
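# The weighting above mirrors EDM2's uncertainty-based loss weighting:
#     weighted_loss = lambda_t * loss / exp(u(t)) + u(t)
# where u(t) is the log-variance predicted by the MLP for timestep t. Minimizing the
# first return value trains both the diffusion model and this MLP; the second return
# value (`loss_scaled`, without the + u(t) regularizer) is usually the more useful
# quantity to log.
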
def create_weight_MLP(noise_scheduler, logvar_channels=128, lambda_weights=None):
    print("creating weight MLP")
    lossweightMLP = AdaptiveLossWeightMLP(noise_scheduler, logvar_channels, lambda_weights)
    MLP_optim = transformers.optimization.Adafactor(
        lossweightMLP.parameters(), 
        lr=1.5e-2,  
        scale_parameter=False, 
        relative_step=False, 
        warmup_init=False
    )
#    MLP_optim = torch.optim.AdamW(lossweightMLP.parameters(), lr=1.5e-2, weight_decay=0)
#    MLP_optim = lion_pytorch.Lion(lossweightMLP.parameters(), lr=2e-2)
#    MLP_optim = ADOPT(lossweightMLP.parameters(), lr=1.5e-2)

    
    disable_scheduler = False  # set to True to disable the LR scheduler (constant multiplier)

    if disable_scheduler:
        def lr_lambda(current_step: int):
            return 1  # const
    else:
        def lr_lambda(current_step: int):
            warmup_steps = 100
            constant_steps = 300
            if current_step <= warmup_steps:
                return current_step / max(1, warmup_steps)
            else:
                return 1 / math.sqrt(max(current_step / (warmup_steps + constant_steps), 1))

    MLP_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer=MLP_optim,
        lr_lambda=lr_lambda
    )
    
    return lossweightMLP, MLP_optim, MLP_scheduler
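
# Minimal usage sketch (illustrative only, not part of the training code above): wires the
# MLP into a single step with a dummy per-sample loss. Assumes a CUDA device is available,
# since the buffers above are placed on `cuda`, and that the caller moves the module's
# parameters to the same device before stepping.
if __name__ == "__main__":
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    mlp, mlp_optim, mlp_sched = create_weight_MLP(scheduler)
    mlp.cuda()  # move the Fourier buffers and linear weight next to alphas_cumprod

    batch = 4
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch,), device='cuda')
    per_sample_loss = torch.rand(batch, device='cuda')  # stand-in for a real per-sample diffusion loss

    weighted, logged = mlp(per_sample_loss, timesteps)
    weighted.mean().backward()
    mlp_optim.step()
    mlp_sched.step()
    mlp_optim.zero_grad()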