|
import torch |
|
import torch.nn as nn |
|
|
|
from src import losses, model_utils |
|
from fvcore.nn import FlopCountAnalysis |
|
from fvcore.nn import flop_count_table |
|
|
|
S2_BANDS = 13 |
|
|
|
class BaseModel(nn.Module): |
|
def __init__( |
|
self, |
|
config |
|
): |
|
super(BaseModel, self).__init__() |
|
self.config = config |
|
self.frozen = False |
|
self.len_epoch = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.scale_by = config.scale_by |
|
|
|
|
|
self.netG = model_utils.get_generator(self.config) |
|
|
|
|
|
self.criterion = losses.get_loss(self.config) |
|
self.log_vars = None |
|
|
|
|
|
paramsG = [{'params': self.netG.parameters()}] |
|
|
|
self.optimizer_G = torch.optim.Adam(paramsG, lr=config.lr) |
|
|
|
|
|
self.scheduler_G = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_G, gamma=self.config.gamma) |
|
|
|
self.real_A = None |
|
self.fake_B = None |
|
self.real_B = None |
|
self.dates = None |
|
self.masks = None |
|
self.netG.variance = None |
|
|
|
def forward(self): |
|
|
|
|
|
self.fake_B = self.netG(self.real_A, batch_positions=self.dates) |
|
if self.config.profile: |
|
flopstats = FlopCountAnalysis(self.netG, (self.real_A, self.dates)) |
|
|
|
|
|
|
|
|
|
self.flops = (flopstats.total()*1e-6)/self.config.batch_size |
|
print(f"MFLOP count: {self.flops}") |
|
self.netG.variance = None |
|
|
|
def backward_G(self): |
|
|
|
self.get_loss_G() |
|
self.loss_G.backward() |
|
|
|
|
|
def get_loss_G(self): |
|
|
|
if hasattr(self.netG, 'vars_idx'): |
|
self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :self.netG.mean_idx, ...], self.real_B, var=self.fake_B[:, :, self.netG.mean_idx:self.netG.vars_idx, ...]) |
|
else: |
|
self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :S2_BANDS, ...], self.real_B, var=self.fake_B[:, :, S2_BANDS:, ...]) |
|
|
|
def set_input(self, input): |
|
self.real_A = self.scale_by * input['A'].to(self.config.device) |
|
self.real_B = self.scale_by * input['B'].to(self.config.device) |
|
self.dates = None if input['dates'] is None else input['dates'].to(self.config.device) |
|
self.masks = input['masks'].to(self.config.device) |
|
|
|
|
|
def reset_input(self): |
|
self.real_A = None |
|
self.real_B = None |
|
self.dates = None |
|
self.masks = None |
|
del self.real_A |
|
del self.real_B |
|
del self.dates |
|
del self.masks |
|
|
|
|
|
def rescale(self): |
|
|
|
if hasattr(self, 'real_A'): self.real_A = 1/self.scale_by * self.real_A |
|
self.real_B = 1/self.scale_by * self.real_B |
|
self.fake_B = 1/self.scale_by * self.fake_B[:,:,:S2_BANDS,...] |
|
|
|
|
|
if hasattr(self.netG, 'variance') and self.netG.variance is not None: |
|
self.netG.variance = 1/self.scale_by**2 * self.netG.variance |
|
|
|
    def optimize_parameters(self):
        """One training step: forward pass, generator update, then cleanup.

        Statement order matters here: the input is freed immediately after the
        forward pass, outputs are rescaled only after the optimizer step, and
        the CPU move happens last.
        """
        self.forward()
        # free the scaled input as soon as the prediction exists;
        # rescale() tolerates the missing attribute via hasattr, and
        # reset_input() recreates it before deleting again
        del self.real_A

        # standard generator update
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # undo the scale_by factor on outputs/targets (also truncates fake_B)
        self.rescale()

        # drop the remaining per-batch tensors
        self.reset_input()

        # NOTE(review): presumably moves results off-device to free GPU
        # memory during training — confirm intent with callers
        if self.netG.training:
            self.fake_B = self.fake_B.cpu()
            if self.netG.variance is not None: self.netG.variance = self.netG.variance.cpu()