""" Main script for image reconstruction experiments Author: Patrick Ebel (github/PatrickTUM), based on the scripts of Vivien Sainte Fare Garnot (github/VSainteuf) License: MIT """ import os import sys import time import json import random import pprint import argparse import numpy as np from tqdm import tqdm from matplotlib import pyplot as plt dirname = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.dirname(dirname)) from parse_args import create_parser from data.dataLoader import SEN12MSCR from src.model_utils import ( get_model, save_model, freeze_layers, load_model, load_checkpoint, ) from src.learning.metrics import img_metrics, avg_img_metrics import torch import torchnet as tnt from torch.utils.tensorboard import SummaryWriter from src import utils, losses from src.learning.weight_init import weight_init S2_BANDS = 13 parser = create_parser(mode="train") config = utils.str2list( parser.parse_args(), list_args=["encoder_widths", "decoder_widths", "out_conv"] ) if config.model in ["unet", "utae"]: assert len(config.encoder_widths) == len(config.decoder_widths) config.loss = "l2" if config.model == "unet": # train U-Net from scratch config.pretrain = True config.trained_checkp = "" if config.pretrain: # pre-training is on a single time point config.input_t = config.n_head = 1 config.sample_type = "pretrain" if config.model == "unet": config.batch_size = 32 config.positional_encoding = False if config.loss in ["GNLL", "MGNLL"]: # for univariate losses, default to univariate mode (batched across channels) if config.loss in ["GNLL"]: config.covmode = "uni" if config.covmode == "iso": config.out_conv[-1] += 1 elif config.covmode in ["uni", "diag"]: config.out_conv[-1] += S2_BANDS config.var_nonLinearity = "softplus" # grab the PID so we can look it up in the logged config for server-side process management config.pid = os.getpid() # import & re-load a previous configuration, e.g. to resume training if config.resume_from: load_conf = os.path.join(config.res_dir, config.experiment_name, "conf.json") if config.experiment_name != config.trained_checkp.split("/")[-2]: raise ValueError("Mismatch of loaded config file and checkpoints") with open(load_conf, "rt") as f: t_args = argparse.Namespace() # do not overwrite the following flags by their respective values in the config file no_overwrite = [ "pid", "num_workers", "root1", "root2", "root3", "resume_from", "trained_checkp", "epochs", "encoder_widths", "decoder_widths", "lr", ] conf_dict = { key: val for key, val in json.load(f).items() if key not in no_overwrite } for key, val in vars(config).items(): if key in no_overwrite: conf_dict[key] = val t_args.__dict__.update(conf_dict) config = parser.parse_args(namespace=t_args) config = utils.str2list( config, list_args=["encoder_widths", "decoder_widths", "out_conv"] ) # resume at a specified epoch and update optimizer accordingly if config.resume_at >= 0: config.lr = config.lr * config.gamma**config.resume_at # fix all RNG seeds, # throw the whole bunch at 'em def seed_packages(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) os.environ["PYTHONHASHSEED"] = str(seed) # torch.use_deterministic_algorithms(True, warn_only=True) torch.backends.cudnn.benchmark = False def seed_worker(worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) # seed everything seed_packages(config.rdm_seed) # seed generators for train & val/test dataloaders f, g = torch.Generator(), torch.Generator() f.manual_seed(config.rdm_seed + 0) # note: this may get re-seeded each epoch g.manual_seed(config.rdm_seed) # keep this one fixed if __name__ == "__main__": pprint.pprint(config) # instantiate tensorboard logger writer = SummaryWriter( os.path.join(os.path.dirname(config.res_dir), "logs", config.experiment_name) ) def plot_img(imgs, mod, plot_dir, file_id=None): if not os.path.exists(plot_dir): os.makedirs(plot_dir) try: imgs = imgs.cpu().numpy() for tdx, img in enumerate(imgs): # iterate over temporal dimension time = "" if imgs.shape[0] == 1 else f"_t-{tdx}" if mod in ["pred", "in", "target", "s2"]: rgb = [3, 2, 1] if img.shape[0] == S2_BANDS else [5, 4, 3] img, val_min, val_max = img[rgb, ...], 0, 1 elif mod == "s1": img, val_min, val_max = img[[0], ...], 0, 1 elif mod == "mask": img, val_min, val_max = img[[0], ...], 0, 1 elif mod == "err": img, val_min, val_max = img[[0], ...], 0, 0.01 elif mod == "var": img, val_min, val_max = img[[0], ...], 0, 0.000025 else: raise NotImplementedError if file_id is not None: # export into file name img = img.clip( val_min, val_max ) # note: this only removes outliers, vmin/vmax below do the global rescaling (else doing instance-wise min/max scaling) plt.imsave( os.path.join(plot_dir, f"img-{file_id}_{mod}{time}.png"), np.moveaxis(img, 0, -1).squeeze(), dpi=100, cmap="gray", vmin=val_min, vmax=val_max, ) except: if isinstance(imgs, plt.Figure): # the passed argument is a pre-rendered figure plt.savefig(os.path.join(plot_dir, f"img-{file_id}_{mod}.png"), dpi=100) else: raise NotImplementedError def export(arrs, mod, export_dir, file_id=None): if not os.path.exists(export_dir): os.makedirs(export_dir) for tdx, arr in enumerate(arrs): # iterate over temporal dimension num = "" if arrs.shape[0] == 1 else f"_t-{tdx}" np.save(os.path.join(export_dir, f"img-{file_id}_{mod}{num}.npy"), arr.cpu()) def prepare_data(batch, device, config): if config.pretrain: return prepare_data_mono(batch, device, config) else: return prepare_data_multi(batch, device, config) def prepare_data_mono(batch, device, config): x = batch["input"]["S2"].to(device).unsqueeze(1) if config.use_sar: x = torch.cat((batch["input"]["S1"].to(device).unsqueeze(1), x), dim=2) m = batch["input"]["masks"].to(device).unsqueeze(1) y = batch["target"]["S2"].to(device).unsqueeze(1) return x, y, m def prepare_data_multi(batch, device, config): in_S2 = recursive_todevice(batch["input"]["S2"], device) in_S2_td = recursive_todevice(batch["input"]["S2 TD"], device) if config.batch_size > 1: in_S2_td = torch.stack((in_S2_td)).T in_m = torch.stack(recursive_todevice(batch["input"]["masks"], device)).swapaxes( 0, 1 ) target_S2 = recursive_todevice(batch["target"]["S2"], device) y = torch.cat(target_S2, dim=0).unsqueeze(1) if config.use_sar: in_S1 = recursive_todevice(batch["input"]["S1"], device) in_S1_td = recursive_todevice(batch["input"]["S1 TD"], device) if config.batch_size > 1: in_S1_td = torch.stack((in_S1_td)).T x = torch.cat((torch.stack(in_S1, dim=1), torch.stack(in_S2, dim=1)), dim=2) dates = ( torch.stack((torch.tensor(in_S1_td), torch.tensor(in_S2_td))) .float() .mean(dim=0) .to(device) ) else: x = torch.stack(in_S2, dim=1) dates = torch.tensor(in_S2_td).float().to(device) return x, y, in_m, dates def log_aleatoric(writer, config, mode, step, var, name, img_meter=None): # if var is of shape [B x 1 x C x C x H x W] then it's a covariance tensor if len(var.shape) > 5: covar = var # get [B x 1 x C x H x W] variance tensor var = var.diagonal(dim1=2, dim2=3).moveaxis(-1, 2) # compute spatial-average to visualize patch-wise covariance matrices patch_covmat = covar.mean(dim=-1).mean(dim=-1).squeeze(dim=1) for bdx, img in enumerate(patch_covmat): # iterate over [B x C x C] covmats img = img.detach().numpy() max_abs = max(abs(img.min()), abs(img.max())) scale_rel_left, scale_rel_right = -max_abs, +max_abs fig = continuous_matshow(img, min=scale_rel_left, max=scale_rel_right) writer.add_figure(f"Img/{mode}/patch covmat relative {bdx}", fig, step) scale_center0_absolute = ( 1 / 4 * 1**2 ) # assuming covmat has been rescaled already, this is an upper bound fig = continuous_matshow( img, min=-scale_center0_absolute, max=scale_center0_absolute ) writer.add_figure(f"Img/{mode}/patch covmat absolute {bdx}", fig, step) # aleatoric uncertainty: comput during train, val and test # note: the quantile statistics are computed solely over the variances (and would be much different if involving covariances, e.g. in the isotopic case) avg_var = torch.mean( var, dim=2, keepdim=True ) # avg over bands, note: this only considers variances (else diag COV's avg would be tiny) q50 = ( avg_var[:, 0, ...].view(avg_var.shape[0], -1).median(dim=-1)[0].detach().clone() ) q75 = ( avg_var[:, 0, ...] .view(avg_var.shape[0], -1) .quantile(0.75, dim=-1) .detach() .clone() ) q50, q75 = q50[0], q75[0] # take batch's first item as a summary binning = 256 # see: https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_histogram if config.loss in ["GNLL", "MGNLL"]: writer.add_image( f"Img/{mode}/{name}aleatoric [0,1]", avg_var[0, 0, ...].clip(0, 1), step, dataformats="CHW", ) # map image to [0, 1] writer.add_image( f"Img/{mode}/{name}aleatoric [0,q75]", avg_var[0, 0, ...].clip(0.0, q75) / q75, step, dataformats="CHW", ) # map image to [0, q75] writer.add_histogram( f"Hist/{mode}/{name}aleatoric", avg_var[0, 0, ...].flatten().clip(0, 1), step, bins=binning, max_bins=binning, ) else: raise NotImplementedError writer.add_scalar(f"{mode}/{name}aleatoric median all", q50, step) writer.add_scalar(f"{mode}/{name}aleatoric q75 all", q75, step) if img_meter is not None: writer.add_scalar(f"{mode}/{name}UCE SE", img_meter.value()["UCE SE"], step) writer.add_scalar(f"{mode}/{name}AUCE SE", img_meter.value()["AUCE SE"], step) def log_train(writer, config, model, step, x, out, y, in_m, name="", var=None): # logged loss is before rescaling by learning rate _, loss = model.criterion, model.loss_G.cpu() if name != "": name = f"model_{name}/" writer.add_scalar(f"train/{name}{config.loss}", loss, step) writer.add_scalar(f"train/{name}total", loss, step) # use add_images for batch-wise adding across temporal dimension if config.use_sar: writer.add_image( f"Img/train/{name}in_s1", x[0, :, [0], ...], step, dataformats="NCHW" ) writer.add_image( f"Img/train/{name}in_s2", x[0, :, [5, 4, 3], ...], step, dataformats="NCHW" ) else: writer.add_image( f"Img/train/{name}in_s2", x[0, :, [3, 2, 1], ...], step, dataformats="NCHW" ) writer.add_image( f"Img/train/{name}out", out[0, 0, [3, 2, 1], ...], step, dataformats="CHW" ) writer.add_image( f"Img/train/{name}y", y[0, 0, [3, 2, 1], ...], step, dataformats="CHW" ) writer.add_image( f"Img/train/{name}m", in_m[0, :, None, ...], step, dataformats="NCHW" ) # analyse cloud coverage # covered at ALL time points (AND) or covered at ANY time points (OR) # and_m, or_m = torch.prod(in_m[0,:, ...], dim=0, keepdim=True), torch.sum(in_m[0,:, ...], dim=0, keepdim=True).clip(0,1) and_m, or_m = torch.prod(in_m, dim=1, keepdim=True), torch.sum( in_m, dim=1, keepdim=True ).clip(0, 1) writer.add_scalar(f"train/{name}OR m %", or_m.float().mean(), step) writer.add_scalar(f"train/{name}AND m %", and_m.float().mean(), step) writer.add_image(f"Img/train/{name}AND m", and_m, step, dataformats="NCHW") writer.add_image(f"Img/train/{name}OR m", or_m, step, dataformats="NCHW") and_m_gray = in_m.float().mean(axis=1).cpu() for bdx, img in enumerate(and_m_gray): fig = discrete_matshow(img, n_colors=config.input_t) writer.add_figure(f"Img/train/temp overlay m {bdx}", fig, step) if var is not None: # log aleatoric uncertainty statistics, excluding computation of ECE log_aleatoric(writer, config, "train", step, var, name, img_meter=None) def discrete_matshow(data, n_colors=5, min=0, max=1): fig, ax = plt.subplots() # get discrete colormap cmap = plt.get_cmap("gray", n_colors + 1) ax.matshow(data, cmap=cmap, vmin=min, vmax=max) ax.axis("off") fig.tight_layout() return fig def continuous_matshow(data, min=0, max=1): fig, ax = plt.subplots() # get discrete colormap cmap = plt.get_cmap("seismic") ax.matshow(data, cmap=cmap, vmin=min, vmax=max) ax.axis("off") # optionally: provide a colorbar and tick at integers # cax = plt.colorbar(mat, ticks=np.arange(min, max + 1)) return fig def iterate(model, data_loader, config, writer, mode="train", epoch=None, device=None): if len(data_loader) == 0: raise ValueError("Received data loader with zero samples!") # loss meter, needs 1 meter per scalar (see https://tnt.readthedocs.io/en/latest/_modules/torchnet/meter/averagevaluemeter.html); loss_meter = tnt.meter.AverageValueMeter() img_meter = avg_img_metrics() # collect sample-averaged uncertainties and errors errs, errs_se, errs_ae, vars_aleatoric = [], [], [], [] t_start = time.time() for i, batch in enumerate(tqdm(data_loader)): step = (epoch - 1) * len(data_loader) + i if config.sample_type == "cloudy_cloudfree": x, y, in_m, dates = prepare_data(batch, device, config) elif config.sample_type == "pretrain": x, y, in_m = prepare_data(batch, device, config) dates = None else: raise NotImplementedError inputs = {"A": x, "B": y, "dates": dates, "masks": in_m} if mode != "train": # val or test with torch.no_grad(): # compute single-model mean and variance predictions model.set_input(inputs) model.forward() model.get_loss_G() model.rescale() out = model.fake_B if hasattr(model.netG, "variance") and model.netG.variance is not None: var = model.netG.variance model.netG.variance = None else: var = out[:, :, S2_BANDS:, ...] out = out[:, :, :S2_BANDS, ...] batch_size = y.size()[0] for bdx in range(batch_size): # only compute statistics on variance estimates if using e.g. NLL loss or combinations thereof if config.loss in ["GNLL", "MGNLL"]: # if the variance variable is of shape [B x 1 x C x C x H x W] then it's a covariance tensor if len(var.shape) > 5: covar = var # get [B x 1 x C x H x W] variance tensor var = var.diagonal(dim1=2, dim2=3).moveaxis(-1, 2) extended_metrics = img_metrics(y[bdx], out[bdx], var=var[bdx]) vars_aleatoric.append(extended_metrics["mean var"]) errs.append(extended_metrics["error"]) errs_se.append(extended_metrics["mean se"]) errs_ae.append(extended_metrics["mean ae"]) else: extended_metrics = img_metrics(y[bdx], out[bdx]) img_meter.add(extended_metrics) idx = i * batch_size + bdx # plot and export every k-th item if config.plot_every > 0 and idx % config.plot_every == 0: plot_dir = os.path.join( config.res_dir, config.experiment_name, "plots", f"epoch_{epoch}", f"{mode}", ) plot_img(x[bdx], "in", plot_dir, file_id=idx) plot_img(out[bdx], "pred", plot_dir, file_id=idx) plot_img(y[bdx], "target", plot_dir, file_id=idx) plot_img( ((out[bdx] - y[bdx]) ** 2).mean(1, keepdims=True), "err", plot_dir, file_id=idx, ) plot_img( discrete_matshow( in_m.float().mean(axis=1).cpu()[bdx], n_colors=config.input_t, ), "mask", plot_dir, file_id=idx, ) if var is not None: plot_img( var.mean(2, keepdims=True)[bdx], "var", plot_dir, file_id=idx, ) if config.export_every > 0 and idx % config.export_every == 0: export_dir = os.path.join( config.res_dir, config.experiment_name, "export", f"epoch_{epoch}", f"{mode}", ) export(out[bdx], "pred", export_dir, file_id=idx) export(y[bdx], "target", export_dir, file_id=idx) if var is not None: try: export(covar[bdx], "covar", export_dir, file_id=idx) except: export(var[bdx], "var", export_dir, file_id=idx) else: # training # compute single-model mean and variance predictions model.set_input(inputs) model.optimize_parameters() # not using model.forward() directly out = model.fake_B.detach().cpu() # read variance predictions stored on generator if hasattr(model.netG, "variance") and model.netG.variance is not None: var = model.netG.variance.cpu() else: var = out[:, :, S2_BANDS:, ...] out = out[:, :, :S2_BANDS, ...] if config.plot_every > 0: plot_out = out.detach().clone() batch_size = y.size()[0] for bdx in range(batch_size): idx = i * batch_size + bdx # plot and export every k-th item if idx % config.plot_every == 0: plot_dir = os.path.join( config.res_dir, config.experiment_name, "plots", f"epoch_{epoch}", f"{mode}", ) plot_img(x[bdx], "in", plot_dir, file_id=i) plot_img(plot_out[bdx], "pred", plot_dir, file_id=i) plot_img(y[bdx], "target", plot_dir, file_id=i) if mode == "train": # periodically log stats if step % config.display_step == 0: out, x, y, in_m = out.cpu(), x.cpu(), y.cpu(), in_m.cpu() if config.loss in ["GNLL", "MGNLL"]: var = var.cpu() log_train(writer, config, model, step, x, out, y, in_m, var=var) else: log_train(writer, config, model, step, x, out, y, in_m) # log the loss, computed via model.backward_G() at train time & via model.get_loss_G() at val/test time loss_meter.add(model.loss_G.item()) # after each batch, close any leftover figures plt.close("all") # --- end of epoch --- # after each epoch, log the loss metrics t_end = time.time() total_time = t_end - t_start print("Epoch time : {:.1f}s".format(total_time)) metrics = {f"{mode}_epoch_time": total_time} # log the loss, only computed within model.backward_G() at train time metrics[f"{mode}_loss"] = loss_meter.value()[0] if mode == "train": # after each epoch, update lr acc. to scheduler current_lr = model.optimizer_G.state_dict()["param_groups"][0]["lr"] writer.add_scalar("Etc/train/lr", current_lr, step) model.scheduler_G.step() if mode == "test" or mode == "val": # log the metrics # log image metrics for key, val in img_meter.value().items(): writer.add_scalar(f"{mode}/{key}", val, step) # any loss is currently only computed within model.backward_G() at train time writer.add_scalar(f"{mode}/loss", metrics[f"{mode}_loss"], step) # use add_images for batch-wise adding across temporal dimension if config.use_sar: writer.add_image( f"Img/{mode}/in_s1", x[0, :, [0], ...], step, dataformats="NCHW" ) writer.add_image( f"Img/{mode}/in_s2", x[0, :, [5, 4, 3], ...], step, dataformats="NCHW" ) else: writer.add_image( f"Img/{mode}/in_s2", x[0, :, [3, 2, 1], ...], step, dataformats="NCHW" ) writer.add_image( f"Img/{mode}/out", out[0, 0, [3, 2, 1], ...], step, dataformats="CHW" ) writer.add_image( f"Img/{mode}/y", y[0, 0, [3, 2, 1], ...], step, dataformats="CHW" ) writer.add_image( f"Img/{mode}/m", in_m[0, :, None, ...], step, dataformats="NCHW" ) # compute Expected Calibration Error (ECE) if config.loss in ["GNLL", "MGNLL"]: sorted_errors_se = compute_ece( vars_aleatoric, errs_se, len(data_loader.dataset), percent=5 ) sorted_errors = {"se_sortAleatoric": sorted_errors_se} plot_discard( sorted_errors["se_sortAleatoric"], config, mode, step, is_se=True ) # compute ECE uce_l2, auce_l2 = compute_uce_auce( vars_aleatoric, errs, len(data_loader.dataset), percent=5, l2=True, mode=mode, step=step, ) # no need for a running mean here img_meter.value()["UCE SE"] = uce_l2.cpu().numpy().item() img_meter.value()["AUCE SE"] = auce_l2.cpu().numpy().item() if config.loss in ["GNLL", "MGNLL"]: log_aleatoric(writer, config, mode, step, var, f"model/", img_meter) return metrics, img_meter.value() else: return metrics def plot_discard(sorted_errors, config, mode, step, is_se=True): metric = "SE" if is_se else "AE" fig, ax = plt.subplots() x_axis = np.arange(0.0, 1.0, 0.05) ax.scatter( x_axis, sorted_errors, c="b", alpha=1.0, marker=r".", label=f"{metric}, sorted by uncertainty", ) # fit a linear regressor with slope b and intercept a sorted_errors[np.isnan(sorted_errors)] = np.nanmean(sorted_errors) b, a = np.polyfit(x_axis, sorted_errors, deg=1) x_seq = np.linspace(0, 1.0, num=1000) ax.plot( x_seq, a + b * x_seq, c="k", lw=1.5, alpha=0.75, label=f"linear fit, {round(a, 3)} + {round(b, 3)} * x", ) plt.xlabel("Fraction of samples, sorted ascendingly by uncertainty") plt.ylabel("Error") plt.legend(loc="upper left") plt.grid() fig.tight_layout() writer.add_figure(f"Img/{mode}/discard_uncertain", fig, step) if mode == "test": # export the final test split plots for print path_to = os.path.join(config.res_dir, config.experiment_name) print(f"Logging discard plots to path {path_to}") fig.savefig( os.path.join(path_to, f"plot_{mode}_{metric}_discard.png"), bbox_inches="tight", dpi=int(1e3), ) fig.savefig( os.path.join(path_to, f"plot_{mode}_{metric}_discard.pdf"), bbox_inches="tight", dpi=int(1e3), ) def compute_ece(vars, errors, n_samples, percent=5): # rank sample-averaged uncertainties ascendingly, and errors accordingly _, vars_indices = torch.sort(torch.Tensor(vars)) errors = torch.Tensor(errors) errs_sort = errors[vars_indices] # incrementally remove 5% of errors, ranked by highest uncertainty bins = torch.linspace(0, n_samples, 100 // percent + 1, dtype=int)[1:] # get uncertainty-sorted cumulative errors, i.e. at x-tick 65% we report the average error for the 65% most certain predictions sorted_errors = np.array( [torch.nanmean(errs_sort[:rdx]).cpu().numpy() for rdx in bins] ) return sorted_errors binarize = lambda arg, n_bins, floor=0, ceil=1: np.digitize( arg, bins=np.linspace(floor, ceil, num=n_bins)[1:] ) def compute_uce_auce(var, errors, n_samples, percent=5, l2=True, mode="val", step=0): n_bins = 100 // percent var, errors = torch.Tensor(var), torch.Tensor(errors) # metric: IN: standard deviation & error # OUT: either root mean variance & root mean squared error or mean standard deviation & mean absolute error metric = ( lambda arg: torch.sqrt(torch.mean(arg**2)) if l2 else torch.mean(torch.abs(arg)) ) m_str = "L2" if l2 else "L1" # group uncertainty values into n_bins var_idx = torch.Tensor(binarize(var, n_bins, floor=var.min(), ceil=var.max())) # compute bin-wise statistics, defaults to nan if no data contained in bin bk_var, bk_err = torch.empty(n_bins), torch.empty(n_bins) for bin_idx in range(n_bins): # for each of the n_bins ... bk_var[bin_idx] = metric( var[var_idx == bin_idx].sqrt() ) # note: taking the sqrt to wrap into metric function, bk_err[bin_idx] = metric( errors[var_idx == bin_idx] ) # apply same metric function on error calib_err = torch.abs( bk_err - bk_var ) # calibration error: discrepancy of error vs uncertainty bk_weight = ( torch.histogram(var_idx, n_bins)[0] / n_samples ) # fraction of total data per bin, for bin-weighting uce = torch.nansum(bk_weight * calib_err) # calc. weighted UCE, auce = torch.nanmean(calib_err) # calc. unweighted AUCE # plot bin-wise error versus bin-wise uncertainty fig, ax = plt.subplots() x_min, x_max = bk_var[~bk_var.isnan()].min(), bk_var[~bk_var.isnan()].max() y_min, y_max = 0, bk_err[~bk_err.isnan()].max() x_axis = np.linspace(x_min, x_max, num=n_bins) ax.plot(x_axis, x_axis) # diagonal reference line ax.bar( x_axis, bk_err, width=x_axis[1] - x_axis[0], alpha=0.75, edgecolor="k", color="gray", ) plt.xlim(x_min, x_max) plt.ylim(y_min, y_max) plt.xlabel("Uncertainty") plt.ylabel(f"{m_str} Error") plt.legend(loc="upper left") plt.grid() fig.tight_layout() writer.add_figure(f"Img/{mode}/err_vs_var_{m_str}", fig, step) return uce, auce def recursive_todevice(x, device): if isinstance(x, torch.Tensor): return x.to(device) elif isinstance(x, dict): return {k: recursive_todevice(v, device) for k, v in x.items()} else: return [recursive_todevice(c, device) for c in x] def prepare_output(config): os.makedirs(os.path.join(config.res_dir, config.experiment_name), exist_ok=True) def checkpoint(log, config): with open( os.path.join(config.res_dir, config.experiment_name, "trainlog.json"), "w" ) as outfile: json.dump(log, outfile, indent=4) def save_results(metrics, path, split="test"): with open(os.path.join(path, f"{split}_metrics.json"), "w") as outfile: json.dump(metrics, outfile, indent=4) # check for file of pre-computed statistics, e.g. indices or cloud coverage def import_from_path(split, config): if os.path.exists( os.path.join(os.path.dirname(os.getcwd()), "util", "precomputed") ): import_path = os.path.join( os.path.dirname(os.getcwd()), "util", "precomputed", f"generic_{config.input_t}_{split}_{config.region}_s2cloudless_mask.npy", ) else: import_path = os.path.join( config.precomputed, f"generic_{config.input_t}_{split}_{config.region}_s2cloudless_mask.npy", ) import_data_path = import_path if os.path.isfile(import_path) else None return import_data_path def main(config): prepare_output(config) device = torch.device(config.device) # define data sets if config.pretrain: # pretrain / training on mono-temporal data dt_train = SEN12MSCR( os.path.expanduser(config.root3), split="train", region=config.region, sample_type=config.sample_type, ) dt_val = SEN12MSCR( os.path.expanduser(config.root3), split="test", region=config.region, sample_type=config.sample_type, ) dt_test = SEN12MSCR( os.path.expanduser(config.root3), split="val", region=config.region, sample_type=config.sample_type, ) else: pass # wrap to allow for subsampling, e.g. for test runs etc dt_train = torch.utils.data.Subset( dt_train, range( 0, min( config.max_samples_count, len(dt_train), int(len(dt_train) * config.max_samples_frac), ), ), ) dt_val = torch.utils.data.Subset( dt_val, range( 0, min( config.max_samples_count, len(dt_val), int(len(dt_train) * config.max_samples_frac), ), ), ) dt_test = torch.utils.data.Subset( dt_test, range( 0, min( config.max_samples_count, len(dt_test), int(len(dt_train) * config.max_samples_frac), ), ), ) # instantiate dataloaders, note: worker_init_fn is needed to get reproducible random samples across runs if vary_samples=True train_loader = torch.utils.data.DataLoader( dt_train, batch_size=config.batch_size, shuffle=True, worker_init_fn=seed_worker, generator=f, num_workers=config.num_workers, pin_memory=True, persistent_workers=True, ) val_loader = torch.utils.data.DataLoader( dt_val, batch_size=config.batch_size, shuffle=False, worker_init_fn=seed_worker, generator=g, num_workers=config.num_workers, pin_memory=True, persistent_workers=True, ) test_loader = torch.utils.data.DataLoader( dt_test, batch_size=config.batch_size, shuffle=False, worker_init_fn=seed_worker, generator=g, num_workers=config.num_workers, pin_memory=True, persistent_workers=True, ) print("Train {}, Val {}, Test {}".format(len(dt_train), len(dt_val), len(dt_test))) # model definition # (compiled model hangs up in validation step on some systems, retry in the future for pytorch > 2.0) model = get_model(config) # torch.compile(get_model(config)) # set model properties model.len_epoch = len(train_loader) config.N_params = utils.get_ntrainparams(model) print("\n\nTrainable layers:") for name, p in model.named_parameters(): if p.requires_grad: print(f"\t{name}") model = model.to(device) # do random weight initialization print("\nInitializing weights randomly.") model.netG.apply(weight_init) if config.trained_checkp and len(config.trained_checkp) > 0: # load weights from the indicated checkpoint print(f"Loading weights from (pre-)trained checkpoint {config.trained_checkp}") load_model( config, model, train_out_layer=False, load_out_partly=config.model in ["uncrtaints"], ) with open( os.path.join(config.res_dir, config.experiment_name, "conf.json"), "w" ) as file: file.write(json.dumps(vars(config), indent=4)) print(f"TOTAL TRAINABLE PARAMETERS: {config.N_params}\n") print(model) # Optimizer and Loss model.criterion = losses.get_loss(config) # track best loss, checkpoint at best validation performance is_better, best_loss = lambda new, prev: new <= prev, float("inf") # Training loop trainlog = {} # resume training at scheduler's latest epoch, != 0 if --resume_from begin_at = ( config.resume_at if config.resume_at >= 0 else model.scheduler_G.state_dict()["last_epoch"] ) for epoch in range(begin_at + 1, config.epochs + 1): print("\nEPOCH {}/{}".format(epoch, config.epochs)) # put all networks in training mode again model.train() model.netG.train() # unfreeze all layers after specified epoch if epoch > config.unfreeze_after and hasattr(model, "frozen") and model.frozen: print("Unfreezing all network layers") model.frozen = False freeze_layers(model.netG, grad=True) # re-seed train generator for each epoch anew, depending on seed choice plus current epoch number # ~ else, dataloader provides same samples no matter what epoch training starts/resumes from # ~ note: only re-seed train split dataloader (if config.vary_samples), but keep all others consistent # ~ if desiring different runs, then the seeds must at least be config.epochs numbers apart if config.vary_samples: # condition dataloader samples on current epoch count f.manual_seed(config.rdm_seed + epoch) train_loader = torch.utils.data.DataLoader( dt_train, batch_size=config.batch_size, shuffle=True, worker_init_fn=seed_worker, generator=f, num_workers=config.num_workers, ) train_metrics = iterate( model, data_loader=train_loader, config=config, writer=writer, mode="train", epoch=epoch, device=device, ) # do regular validation steps at the end of each training epoch if epoch % config.val_every == 0 and epoch > config.val_after: print("Validation . . . ") model.eval() model.netG.eval() val_metrics, val_img_metrics = iterate( model, data_loader=val_loader, config=config, writer=writer, mode="val", epoch=epoch, device=device, ) # use the training loss for validation print("Using training loss as validation loss") if "val_loss" in val_metrics: val_loss = val_metrics["val_loss"] else: val_loss = val_metrics["val_loss_ensembleAverage"] print(f"Validation Loss {val_loss}") print(f"validation image metrics: {val_img_metrics}") save_results( val_img_metrics, os.path.join(config.res_dir, config.experiment_name), split=f"test_epoch_{epoch}", ) print( f"\nLogged validation epoch {epoch} metrics to path {os.path.join(config.res_dir, config.experiment_name)}" ) # checkpoint best model trainlog[epoch] = {**train_metrics, **val_metrics} checkpoint(trainlog, config) if is_better(val_loss, best_loss): best_loss = val_loss save_model(config, epoch, model, "model") else: trainlog[epoch] = {**train_metrics} checkpoint(trainlog, config) # always checkpoint the current epoch's model save_model(config, epoch, model, f"model_epoch_{epoch}") print(f"Completed current epoch of experiment {config.experiment_name}.") # following training, test on hold-out data print("Testing best epoch . . .") load_checkpoint(config, config.res_dir, model, "model") model.eval() model.netG.eval() test_metrics, test_img_metrics = iterate( model, data_loader=test_loader, config=config, writer=writer, mode="test", epoch=epoch, device=device, ) if "test_loss" in test_metrics: test_loss = test_metrics["test_loss"] else: test_loss = test_metrics["test_loss_ensembleAverage"] print(f"Test Loss {test_loss}") print(f"\nTest image metrics: {test_img_metrics}") save_results( test_img_metrics, os.path.join(config.res_dir, config.experiment_name), split="test", ) print( f"\nLogged test metrics to path {os.path.join(config.res_dir, config.experiment_name)}" ) # close tensorboard logging writer.close() print(f"Finished training experiment {config.experiment_name}.") if __name__ == "__main__": main(config) exit()