 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn

from src import losses, model_utils
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table

S2_BANDS = 13  # channel count of the mean prediction; presumably the 13 Sentinel-2 spectral bands — TODO confirm against the generator's output layout

class BaseModel(nn.Module):
    """Training wrapper around a generator network ``netG``.

    Owns the generator, its loss criterion, Adam optimizer and exponential LR
    scheduler, and implements one optimization step (``optimize_parameters``).
    Inputs and targets are temporarily multiplied by ``config.scale_by`` on the
    way in (``set_input``) and divided back out after the loss is computed
    (``rescale``) — see the scaling protocol comments in ``__init__``.

    NOTE(review): the methods form an ordered protocol —
    ``set_input`` -> ``forward`` -> ``get_loss_G``/``backward_G`` -> ``rescale``
    -> ``reset_input`` — and several of them mutate shared attributes
    (``real_A``, ``fake_B``, ``netG.variance``); call them out of order at
    your own risk.
    """

    def __init__(
        self,
        config
    ):
        """Build generator, criterion, optimizer and scheduler from ``config``.

        Args:
            config: experiment configuration object; the attributes read here
                are ``scale_by``, ``lr``, ``gamma`` (and, by the other methods,
                ``device``, ``profile``, ``batch_size``).
        """
        super(BaseModel, self).__init__()
        self.config     = config    # store config
        self.frozen     = False     # no parameters are frozen
        self.len_epoch  = 0         # steps of one epoch

        # temporarily rescale model inputs & outputs by constant factor, e.g. from [0,1] to [0,100],
        #       to deal with numerical imprecision issues closeby 0 magnitude (and their inverses)
        #       --- convert inputs, mean & variance predictions to original scale again after NLL loss is computed
        # note: this may also require adjusting the range of output nonlinearities in the generator network,
        #       i.e. out_mean, out_var and diag_var

        #   -------------- set input via set_input and call forward ---------------
        # inputs self.real_A & self.real_B  set in set_input            by * self.scale_by
        #   ------------------------------ then scale -----------------------------
        # output self.fake_B will automatically get scaled              by ''
        #   ------------------- then compute loss via get_loss_G ------------------
        # output self.netG.variance  will automatically get scaled      by * self.scale_by**2
        #   ----------------------------- then rescale ----------------------------
        # inputs self.real_A & self.real_B  set in set_input            by * 1/self.scale_by
        # output self.fake_B                set in self.forward         by * 1/self.scale_by
        # output self.netG.variance         set in get_loss_G           by * 1/self.scale_by**2
        self.scale_by  = config.scale_by                    # temporarily rescale model inputs by constant factor, e.g. from [0,1] to [0,100]

        # fetch generator
        self.netG = model_utils.get_generator(self.config)

        # 1 criterion
        self.criterion = losses.get_loss(self.config)
        self.log_vars  = None

        # 2 optimizer: for G
        paramsG = [{'params': self.netG.parameters()}]

        self.optimizer_G = torch.optim.Adam(paramsG, lr=config.lr)

        # 2 scheduler: for G, note: stepping takes place at the end of epoch
        self.scheduler_G = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_G, gamma=self.config.gamma)

        # batch state, populated per step by set_input() / forward()
        self.real_A = None  # scaled input batch
        self.fake_B = None  # generator prediction, set in forward()
        self.real_B = None  # scaled target batch
        self.dates  = None  # optional temporal positions for the generator
        self.masks  = None  # per-sample masks from the dataloader (not used in this class)
        self.netG.variance = None  # variance prediction, (re)computed in get_loss_G()

    def forward(self):
        """Run the generator on the current batch, storing the result in ``self.fake_B``.

        If ``config.profile`` is set, additionally measures the generator's FLOP
        count with fvcore and stores MFLOPs-per-sample in ``self.flops``.
        Any variance from a previous step is discarded; ``get_loss_G`` recomputes it.
        """
        # forward through generator, note: for val/test splits, 
        # 'with torch.no_grad():' is declared in train script
        self.fake_B = self.netG(self.real_A, batch_positions=self.dates)
        if self.config.profile: 
            # NOTE(review): this re-runs the generator inside FlopCountAnalysis,
            # so profiling roughly doubles the cost of each forward pass
            flopstats  = FlopCountAnalysis(self.netG, (self.real_A, self.dates))
            # print(flop_count_table(flopstats))
            # TFLOPS: flopstats.total() *1e-12
            # MFLOPS: flopstats.total() *1e-6
            # compute MFLOPS per input sample
            self.flops = (flopstats.total()*1e-6)/self.config.batch_size
            print(f"MFLOP count: {self.flops}")
        self.netG.variance = None # purge earlier variance prediction, re-compute via get_loss_G()

    def backward_G(self):
        """Compute the generator loss for the current batch and backpropagate it."""
        # calculate generator loss
        self.get_loss_G()
        self.loss_G.backward()


    def get_loss_G(self):
        """Compute ``self.loss_G`` and ``self.netG.variance`` from the current prediction.

        ``fake_B``'s channel axis (dim 2) packs mean and variance predictions.
        If the generator declares its own split points (``mean_idx``/``vars_idx``),
        those are used; otherwise the first S2_BANDS channels are the mean and the
        remainder the variance.
        """
        if hasattr(self.netG, 'vars_idx'):
            self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :self.netG.mean_idx, ...], self.real_B, var=self.fake_B[:, :, self.netG.mean_idx:self.netG.vars_idx, ...])
        else: # used with all other models
            self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :S2_BANDS, ...], self.real_B, var=self.fake_B[:, :, S2_BANDS:, ...])

    def set_input(self, input):
        """Load one dataloader batch onto the device, pre-scaled by ``scale_by``.

        Args:
            input: dict with keys 'A' (input), 'B' (target), 'dates' (optional
                temporal positions, may be None) and 'masks'.
        """
        self.real_A = self.scale_by * input['A'].to(self.config.device)
        self.real_B = self.scale_by * input['B'].to(self.config.device)
        self.dates  = None if input['dates'] is None else input['dates'].to(self.config.device)
        self.masks  = input['masks'].to(self.config.device)


    def reset_input(self):
        """Drop all references to the current batch so its memory can be reclaimed."""
        # NOTE(review): assigning None and then deleting the attribute is
        # redundant — either alone releases the tensor reference
        self.real_A = None
        self.real_B = None
        self.dates  = None 
        self.masks  = None
        del self.real_A
        del self.real_B 
        del self.dates
        del self.masks


    def rescale(self):
        """Undo the ``scale_by`` factor on inputs, targets, means and variances.

        Also truncates ``fake_B`` to its first S2_BANDS (mean) channels — the
        variance channels live on in ``netG.variance`` after ``get_loss_G``.
        NOTE(review): assumes set_input/forward ran first; if real_B/fake_B are
        still None this raises — confirm callers respect the protocol.
        """
        # rescale target and mean predictions
        # hasattr guard: optimize_parameters() deletes real_A before calling rescale()
        if hasattr(self, 'real_A'): self.real_A = 1/self.scale_by * self.real_A
        self.real_B = 1/self.scale_by * self.real_B 
        self.fake_B = 1/self.scale_by * self.fake_B[:,:,:S2_BANDS,...]
        
        # rescale (co)variances
        if hasattr(self.netG, 'variance') and self.netG.variance is not None:
            self.netG.variance = 1/self.scale_by**2 * self.netG.variance

    def optimize_parameters(self):
        """Run one full optimization step on the batch loaded via ``set_input``.

        forward -> zero_grad -> backward -> step, then rescales outputs back to
        the original value range and releases the batch. In training mode the
        (rescaled) prediction and variance are moved to CPU.
        """
        self.forward()
        # input no longer needed once fake_B exists; free it before backward to cut peak memory
        del self.real_A

        # update G
        self.optimizer_G.zero_grad() 
        self.backward_G()
        self.optimizer_G.step()

        # re-scale inputs, predicted means, predicted variances, etc
        self.rescale()
        # resetting inputs after optimization saves memory
        self.reset_input()

        if self.netG.training: 
            self.fake_B = self.fake_B.cpu()
            if self.netG.variance is not None: self.netG.variance = self.netG.variance.cpu()