# Provenance: Hugging Face Hub upload by XavierJiezou, commit 3c8ff2e (verified),
# "Upload folder using huggingface_hub".
import torch
import torch.nn as nn
from src import losses, model_utils
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
S2_BANDS = 13  # number of optical bands sliced from predictions (presumably Sentinel-2's 13 spectral bands — confirm)
class BaseModel(nn.Module):
def __init__(
self,
config
):
super(BaseModel, self).__init__()
self.config = config # store config
self.frozen = False # no parameters are frozen
self.len_epoch = 0 # steps of one epoch
# temporarily rescale model inputs & outputs by constant factor, e.g. from [0,1] to [0,100],
# to deal with numerical imprecision issues closeby 0 magnitude (and their inverses)
# --- convert inputs, mean & variance predictions to original scale again after NLL loss is computed
# note: this may also require adjusting the range of output nonlinearities in the generator network,
# i.e. out_mean, out_var and diag_var
# -------------- set input via set_input and call forward ---------------
# inputs self.real_A & self.real_B set in set_input by * self.scale_by
# ------------------------------ then scale -----------------------------
# output self.fake_B will automatically get scaled by ''
# ------------------- then compute loss via get_loss_G ------------------
# output self.netG.variance will automatically get scaled by * self.scale_by**2
# ----------------------------- then rescale ----------------------------
# inputs self.real_A & self.real_B set in set_input by * 1/self.scale_by
# output self.fake_B set in self.forward by * 1/self.scale_by
# output self.netG.variance set in get_loss_G by * 1/self.scale_by**2
self.scale_by = config.scale_by # temporarily rescale model inputs by constant factor, e.g. from [0,1] to [0,100]
# fetch generator
self.netG = model_utils.get_generator(self.config)
# 1 criterion
self.criterion = losses.get_loss(self.config)
self.log_vars = None
# 2 optimizer: for G
paramsG = [{'params': self.netG.parameters()}]
self.optimizer_G = torch.optim.Adam(paramsG, lr=config.lr)
# 2 scheduler: for G, note: stepping takes place at the end of epoch
self.scheduler_G = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_G, gamma=self.config.gamma)
self.real_A = None
self.fake_B = None
self.real_B = None
self.dates = None
self.masks = None
self.netG.variance = None
def forward(self):
# forward through generator, note: for val/test splits,
# 'with torch.no_grad():' is declared in train script
self.fake_B = self.netG(self.real_A, batch_positions=self.dates)
if self.config.profile:
flopstats = FlopCountAnalysis(self.netG, (self.real_A, self.dates))
# print(flop_count_table(flopstats))
# TFLOPS: flopstats.total() *1e-12
# MFLOPS: flopstats.total() *1e-6
# compute MFLOPS per input sample
self.flops = (flopstats.total()*1e-6)/self.config.batch_size
print(f"MFLOP count: {self.flops}")
self.netG.variance = None # purge earlier variance prediction, re-compute via get_loss_G()
def backward_G(self):
# calculate generator loss
self.get_loss_G()
self.loss_G.backward()
def get_loss_G(self):
if hasattr(self.netG, 'vars_idx'):
self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :self.netG.mean_idx, ...], self.real_B, var=self.fake_B[:, :, self.netG.mean_idx:self.netG.vars_idx, ...])
else: # used with all other models
self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :S2_BANDS, ...], self.real_B, var=self.fake_B[:, :, S2_BANDS:, ...])
def set_input(self, input):
self.real_A = self.scale_by * input['A'].to(self.config.device)
self.real_B = self.scale_by * input['B'].to(self.config.device)
self.dates = None if input['dates'] is None else input['dates'].to(self.config.device)
self.masks = input['masks'].to(self.config.device)
def reset_input(self):
self.real_A = None
self.real_B = None
self.dates = None
self.masks = None
del self.real_A
del self.real_B
del self.dates
del self.masks
def rescale(self):
# rescale target and mean predictions
if hasattr(self, 'real_A'): self.real_A = 1/self.scale_by * self.real_A
self.real_B = 1/self.scale_by * self.real_B
self.fake_B = 1/self.scale_by * self.fake_B[:,:,:S2_BANDS,...]
# rescale (co)variances
if hasattr(self.netG, 'variance') and self.netG.variance is not None:
self.netG.variance = 1/self.scale_by**2 * self.netG.variance
def optimize_parameters(self):
self.forward()
del self.real_A
# update G
self.optimizer_G.zero_grad()
self.backward_G()
self.optimizer_G.step()
# re-scale inputs, predicted means, predicted variances, etc
self.rescale()
# resetting inputs after optimization saves memory
self.reset_input()
if self.netG.training:
self.fake_B = self.fake_B.cpu()
if self.netG.variance is not None: self.netG.variance = self.netG.variance.cpu()