import torch
import torch.nn as nn
from src import losses, model_utils
from fvcore.nn import FlopCountAnalysis, flop_count_table

S2_BANDS = 13


class BaseModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # store config
        self.frozen = False   # initially, no parameters are frozen
        self.len_epoch = 0    # number of steps in one epoch
        # Temporarily rescale model inputs & outputs by a constant factor, e.g. from [0, 1] to [0, 100],
        # to mitigate numerical imprecision near magnitudes of 0 (and their inverses).
        # Inputs and the mean & variance predictions are converted back to the original scale
        # after the NLL loss has been computed. Note: this may also require adjusting the range
        # of the output nonlinearities in the generator network, i.e. out_mean, out_var and diag_var.
        #
        # Scaling protocol:
        #   1. set the input via set_input, then call forward:
        #      - inputs self.real_A & self.real_B are scaled in set_input by * self.scale_by
        #      - output self.fake_B is thereby scaled by the same factor
        #   2. compute the loss via get_loss_G:
        #      - output self.netG.variance is thereby scaled by * self.scale_by**2
        #   3. undo the scaling via rescale:
        #      - inputs self.real_A & self.real_B are rescaled by * 1/self.scale_by
        #      - output self.fake_B is rescaled by * 1/self.scale_by
        #      - output self.netG.variance is rescaled by * 1/self.scale_by**2
        self.scale_by = config.scale_by  # constant rescaling factor, e.g. from [0, 1] to [0, 100]
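        # Worked example (hypothetical numbers): with scale_by = 100, an input
        # reflectance of 0.01 enters the network as 1.0, and a variance of 4.0
        # predicted at that scale maps back to 4.0 / 100**2 = 4e-4 in the
        # original units, since Var(c * X) = c**2 * Var(X).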
        # fetch generator
        self.netG = model_utils.get_generator(self.config)

        # 1. criterion
        self.criterion = losses.get_loss(self.config)
        self.log_vars = None

        # 2. optimizer for G
        paramsG = [{'params': self.netG.parameters()}]
        self.optimizer_G = torch.optim.Adam(paramsG, lr=config.lr)

        # 3. scheduler for G; note: stepping takes place at the end of each epoch
        self.scheduler_G = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_G, gamma=self.config.gamma)
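        # ExponentialLR multiplies the learning rate by gamma on every call to
        # scheduler_G.step(), i.e. after n epochs the rate is lr * gamma**n
        # (hypothetical numbers: lr = 1e-3 and gamma = 0.9 yield 1e-3, 9e-4,
        # 8.1e-4, ... over successive epochs).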
        self.real_A = None
        self.fake_B = None
        self.real_B = None
        self.dates = None
        self.masks = None
        self.netG.variance = None

    def forward(self):
        # forward pass through the generator; note: for val/test splits,
        # 'with torch.no_grad():' is declared in the train script
        self.fake_B = self.netG(self.real_A, batch_positions=self.dates)

        if self.config.profile:
            flopstats = FlopCountAnalysis(self.netG, (self.real_A, self.dates))
            # print(flop_count_table(flopstats))
            # TFLOPS: flopstats.total() * 1e-12
            # MFLOPS: flopstats.total() * 1e-6
            # compute MFLOPS per input sample
            self.flops = (flopstats.total() * 1e-6) / self.config.batch_size
            print(f"MFLOP count: {self.flops}")

        self.netG.variance = None  # purge any earlier variance prediction, re-computed via get_loss_G()
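
    # Note on the profiling in forward: fvcore counts one fused multiply-add as
    # a single flop, and FlopCountAnalysis measures one forward pass over the
    # whole batch, hence the division by config.batch_size for a per-sample
    # figure; uncommenting flop_count_table(flopstats) prints a per-module
    # breakdown.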

    def backward_G(self):
        # calculate the generator loss
        self.get_loss_G()
        self.loss_G.backward()

    def get_loss_G(self):
        if hasattr(self.netG, 'vars_idx'):
            self.loss_G, self.netG.variance = losses.calc_loss(
                self.criterion, self.config,
                self.fake_B[:, :, :self.netG.mean_idx, ...], self.real_B,
                var=self.fake_B[:, :, self.netG.mean_idx:self.netG.vars_idx, ...])
        else:  # used with all other models
            self.loss_G, self.netG.variance = losses.calc_loss(
                self.criterion, self.config,
                self.fake_B[:, :, :S2_BANDS, ...], self.real_B,
                var=self.fake_B[:, :, S2_BANDS:, ...])

    def set_input(self, input):
        self.real_A = self.scale_by * input['A'].to(self.config.device)
        self.real_B = self.scale_by * input['B'].to(self.config.device)
        self.dates = None if input['dates'] is None else input['dates'].to(self.config.device)
        self.masks = input['masks'].to(self.config.device)

    def reset_input(self):
        # re-assign before deleting so that a repeated call cannot raise an
        # AttributeError; deleting the attributes releases the batch tensors
        self.real_A = None
        self.real_B = None
        self.dates = None
        self.masks = None
        del self.real_A
        del self.real_B
        del self.dates
        del self.masks

    def rescale(self):
        # rescale the target and the mean predictions back to the original range
        # (self.real_A may already have been deleted in optimize_parameters)
        if hasattr(self, 'real_A'):
            self.real_A = 1 / self.scale_by * self.real_A
        self.real_B = 1 / self.scale_by * self.real_B
        self.fake_B = 1 / self.scale_by * self.fake_B[:, :, :S2_BANDS, ...]

        # rescale the (co)variances
        if hasattr(self.netG, 'variance') and self.netG.variance is not None:
            self.netG.variance = 1 / self.scale_by ** 2 * self.netG.variance

    def optimize_parameters(self):
        self.forward()
        del self.real_A

        # update G
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # rescale inputs, predicted means, predicted variances, etc.
        self.rescale()
        # resetting the inputs after optimization saves memory
        self.reset_input()

        if self.netG.training:
            self.fake_B = self.fake_B.cpu()
            if self.netG.variance is not None:
                self.netG.variance = self.netG.variance.cpu()
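

# A minimal training-loop sketch (hypothetical: assumes that `config` exposes
# scale_by, lr, gamma, device, batch_size and profile, and that each batch is
# a dict with keys 'A', 'B', 'dates' and 'masks', as expected by set_input):
#
#   model = BaseModel(config)
#   for epoch in range(n_epochs):
#       for batch in train_loader:
#           model.set_input(batch)       # scales inputs by config.scale_by
#           model.optimize_parameters()  # forward, loss, backward, step, rescale
#       model.scheduler_G.step()         # LR decay at the end of each epoch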