|
""" |
|
Main script for image reconstruction experiments |
|
Author: Patrick Ebel (github/PatrickTUM), based on the scripts of |
|
Vivien Sainte Fare Garnot (github/VSainteuf) |
|
License: MIT |
|
""" |
|
|
|
|
|
import os |
|
import sys |
|
import time |
|
import json |
|
import random |
|
import pprint |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
from matplotlib import pyplot as plt |
|
|
|
dirname = os.path.dirname(os.path.abspath(__file__)) |
|
sys.path.append(os.path.dirname(dirname)) |
|
|
|
from parse_args import create_parser |
|
from data.dataLoader import SEN12MSCR |
|
from src.model_utils import ( |
|
get_model, |
|
save_model, |
|
freeze_layers, |
|
load_model, |
|
load_checkpoint, |
|
) |
|
from src.learning.metrics import img_metrics, avg_img_metrics |
|
|
|
import torch |
|
import torchnet as tnt |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
from src import utils, losses |
|
from src.learning.weight_init import weight_init |
|
|
|
# Number of Sentinel-2 bands. In iterate(), the first S2_BANDS channels of a
# model output are treated as the image and any further channels as variance.
S2_BANDS = 13

parser = create_parser(mode="train")
config = utils.str2list(
    parser.parse_args(), list_args=["encoder_widths", "decoder_widths", "out_conv"]
)

# Model-specific overrides of the parsed configuration.
if config.model in ["unet", "utae"]:
    assert len(config.encoder_widths) == len(config.decoder_widths)
    config.loss = "l2"
    if config.model == "unet":
        # U-Net: force single-time-point ("pretrain") mode and ignore any
        # checkpoint given on the command line.
        config.pretrain = True
        config.trained_checkp = ""

# Pretraining works on a single time point: input_t and n_head are forced to 1.
if config.pretrain:
    config.input_t = config.n_head = 1
    config.sample_type = "pretrain"
    if config.model == "unet":
        config.batch_size = 32
    config.positional_encoding = False

# Likelihood losses need extra output channels for the predicted variance.
if config.loss in ["GNLL", "MGNLL"]:
    # GNLL always uses the per-channel ("uni") covariance mode.
    if config.loss in ["GNLL"]:
        config.covmode = "uni"

    if config.covmode == "iso":
        # a single extra output channel
        config.out_conv[-1] += 1
    elif config.covmode in ["uni", "diag"]:
        # one extra output channel per S2 band
        config.out_conv[-1] += S2_BANDS
        config.var_nonLinearity = "softplus"

# Recorded in conf.json (written by main) so the run's process can be identified.
config.pid = os.getpid()

# Resume: reload this experiment's saved conf.json, but keep the current values
# of every key listed in `no_overwrite`.
# NOTE: out_conv is not in `no_overwrite`, so the saved value (which already
# includes the variance channels added above in the original run) replaces
# the one computed above.
if config.resume_from:
    load_conf = os.path.join(config.res_dir, config.experiment_name, "conf.json")
    # the checkpoint's parent directory must be named after this experiment
    if config.experiment_name != config.trained_checkp.split("/")[-2]:
        raise ValueError("Mismatch of loaded config file and checkpoints")
    with open(load_conf, "rt") as f:
        t_args = argparse.Namespace()

        # flags that are never taken from the saved config file
        no_overwrite = [
            "pid",
            "num_workers",
            "root1",
            "root2",
            "root3",
            "resume_from",
            "trained_checkp",
            "epochs",
            "encoder_widths",
            "decoder_widths",
            "lr",
        ]
        conf_dict = {
            key: val for key, val in json.load(f).items() if key not in no_overwrite
        }
        for key, val in vars(config).items():
            if key in no_overwrite:
                conf_dict[key] = val
        t_args.__dict__.update(conf_dict)
        # Re-parse on top of the merged namespace: argparse assigns the parsed
        # command-line values onto `t_args` and only fills in defaults for
        # attributes that are missing from it.
        config = parser.parse_args(namespace=t_args)
# Re-parsing may have turned list-valued options back into strings
# (presumably what str2list converts — TODO confirm).
config = utils.str2list(
    config, list_args=["encoder_widths", "decoder_widths", "out_conv"]
)

# When resuming at a given epoch, start from the decayed learning rate
# lr * gamma ** resume_at.
if config.resume_at >= 0:
    config.lr = config.lr * config.gamma**config.resume_at
|
|
|
|
|
|
|
|
|
def seed_packages(seed):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) with ``seed``.

    Also exports ``PYTHONHASHSEED`` and turns off cuDNN benchmark autotuning,
    which would otherwise select kernels non-deterministically.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
def seed_worker(worker_id):
    """``worker_init_fn`` for DataLoaders: derive NumPy/``random`` seeds from torch.

    ``torch.initial_seed()`` is reduced modulo 2**32 because NumPy only accepts
    32-bit seeds. ``worker_id`` is unused but required by the DataLoader API.
    """
    derived_seed = torch.initial_seed() % 2**32
    random.seed(derived_seed)
    np.random.seed(derived_seed)
|
|
|
|
|
|
|
# Seed all global RNGs from the configured seed.
seed_packages(config.rdm_seed)

# Dedicated generators for the data loaders: `f` drives shuffling of the
# training loader (and is re-seeded every epoch in main() when
# config.vary_samples is set); `g` is given to the val/test loaders.
f = torch.Generator()
f.manual_seed(config.rdm_seed)
g = torch.Generator()
g.manual_seed(config.rdm_seed)

if __name__ == "__main__":
    pprint.pprint(config)

# Module-level TensorBoard writer; several functions below log through it.
log_dir = os.path.join(os.path.dirname(config.res_dir), "logs", config.experiment_name)
writer = SummaryWriter(log_dir)
|
|
|
|
|
def plot_img(imgs, mod, plot_dir, file_id=None):
    """Save a tensor (one PNG per entry along dim 0) or a matplotlib Figure.

    Args:
        imgs: a torch tensor whose first dimension is iterated (presumably time,
            i.e. ``[T, C, H, W]`` — TODO confirm), or a ``matplotlib`` Figure.
        mod: kind of image; selects channels and display range:
            "pred"/"in"/"target"/"s2" -> RGB bands in [0, 1],
            "s1"/"mask" -> channel 0 in [0, 1],
            "err" -> channel 0 in [0, 0.01],
            "var" -> channel 0 in [0, 2.5e-5].
        plot_dir: output directory, created if missing.
        file_id: identifier used in the file names; pixel values are clipped to
            the display range only when it is given.

    Raises:
        NotImplementedError: for an unknown ``mod`` or an unsupported ``imgs`` type.
    """
    os.makedirs(plot_dir, exist_ok=True)

    if isinstance(imgs, plt.Figure):
        # Save the figure that was passed in; plt.savefig would save whichever
        # figure happens to be current.
        imgs.savefig(os.path.join(plot_dir, f"img-{file_id}_{mod}.png"), dpi=100)
        return
    if not isinstance(imgs, torch.Tensor):
        raise NotImplementedError(f"Cannot plot object of type {type(imgs).__name__}")

    imgs = imgs.detach().cpu().numpy()
    for tdx, img in enumerate(imgs):
        # only add a time suffix when there is more than one entry
        suffix = "" if imgs.shape[0] == 1 else f"_t-{tdx}"
        if mod in ["pred", "in", "target", "s2"]:
            # Plain S2 input: RGB = bands [3, 2, 1]. Otherwise the S2 bands are
            # shifted by the S1 channels that precede them.
            rgb = [3, 2, 1] if img.shape[0] == S2_BANDS else [5, 4, 3]
            img, val_min, val_max = img[rgb, ...], 0, 1
        elif mod in ["s1", "mask"]:
            img, val_min, val_max = img[[0], ...], 0, 1
        elif mod == "err":
            img, val_min, val_max = img[[0], ...], 0, 0.01
        elif mod == "var":
            img, val_min, val_max = img[[0], ...], 0, 0.000025
        else:
            raise NotImplementedError
        if file_id is not None:
            img = img.clip(val_min, val_max)
        plt.imsave(
            os.path.join(plot_dir, f"img-{file_id}_{mod}{suffix}.png"),
            np.moveaxis(img, 0, -1).squeeze(),
            dpi=100,
            cmap="gray",
            vmin=val_min,
            vmax=val_max,
        )
|
|
|
|
|
def export(arrs, mod, export_dir, file_id=None):
    """Save every entry of ``arrs`` along its first dimension as a ``.npy`` file.

    Files are named ``img-{file_id}_{mod}.npy`` when ``arrs`` has a single entry
    and ``img-{file_id}_{mod}_t-{k}.npy`` (k = index along dim 0) otherwise.

    Args:
        arrs: torch tensor to export.
        mod: name of the exported quantity, used in the file name.
        export_dir: output directory, created (with parents) if missing.
        file_id: identifier used in the file names.
    """
    os.makedirs(export_dir, exist_ok=True)
    for tdx, arr in enumerate(arrs):
        num = "" if arrs.shape[0] == 1 else f"_t-{tdx}"
        # detach: numpy conversion fails for tensors that still require grad
        np.save(
            os.path.join(export_dir, f"img-{file_id}_{mod}{num}.npy"),
            arr.detach().cpu().numpy(),
        )
|
|
|
|
|
def prepare_data(batch, device, config):
    """Dispatch batch preparation on ``config.pretrain``.

    Single-time-point (pretrain) batches go to ``prepare_data_mono`` and return
    ``(x, y, m)``; all others go to ``prepare_data_multi`` and return
    ``(x, y, in_m, dates)``.
    """
    preparer = prepare_data_mono if config.pretrain else prepare_data_multi
    return preparer(batch, device, config)
|
|
|
|
|
def prepare_data_mono(batch, device, config):
    """Move a single-time-point batch to ``device`` and add a time axis at dim 1.

    With ``config.use_sar`` the S1 input is concatenated in front of the S2
    input along the channel dimension (dim 2 after the time axis is added).

    Returns:
        ``(x, y, m)``: inputs, S2 target and masks, each with a singleton dim 1.
    """

    def with_time_axis(tensor):
        return tensor.to(device).unsqueeze(1)

    inputs, targets = batch["input"], batch["target"]
    x = with_time_axis(inputs["S2"])
    if config.use_sar:
        x = torch.cat((with_time_axis(inputs["S1"]), x), dim=2)
    m = with_time_axis(inputs["masks"])
    y = with_time_axis(targets["S2"])
    return x, y, m
|
|
|
|
|
def prepare_data_multi(batch, device, config):
    """Move a multi-temporal batch to ``device`` and assemble the model inputs.

    Returns ``(x, y, in_m, dates)``:
      * ``x``: the S2 inputs stacked along dim 1; with ``config.use_sar`` the S1
        inputs are stacked the same way and concatenated in front of S2 along dim 2.
      * ``y``: the target S2 tensors concatenated along dim 0, plus a singleton dim 1.
      * ``in_m``: the masks stacked along a new dim 0, then dims 0 and 1 swapped.
      * ``dates``: float date tensor on ``device``; with SAR, the element-wise
        mean of the S1 and S2 dates.

    Assumption (not visible here — TODO confirm against the data loader): the
    "S2", "S1" and "masks" entries are lists with one batched tensor per time
    step, and the "... TD" entries hold one date value (per sample) per time step.
    """
    in_S2 = recursive_todevice(batch["input"]["S2"], device)
    in_S2_td = recursive_todevice(batch["input"]["S2 TD"], device)
    # Stack the per-time-step dates and transpose.
    # NOTE(review): keyed on config.batch_size, not on the size of this batch.
    if config.batch_size > 1:
        in_S2_td = torch.stack((in_S2_td)).T
    in_m = torch.stack(recursive_todevice(batch["input"]["masks"], device)).swapaxes(
        0, 1
    )
    target_S2 = recursive_todevice(batch["target"]["S2"], device)
    y = torch.cat(target_S2, dim=0).unsqueeze(1)

    if config.use_sar:
        in_S1 = recursive_todevice(batch["input"]["S1"], device)
        in_S1_td = recursive_todevice(batch["input"]["S1 TD"], device)
        if config.batch_size > 1:
            in_S1_td = torch.stack((in_S1_td)).T
        # S1 channels first, then S2 (channel dim 2)
        x = torch.cat((torch.stack(in_S1, dim=1), torch.stack(in_S2, dim=1)), dim=2)
        # NOTE(review): torch.tensor() on an existing tensor copies it and emits a
        # UserWarning; behaviour kept as-is because the input may also be a list.
        dates = (
            torch.stack((torch.tensor(in_S1_td), torch.tensor(in_S2_td)))
            .float()
            .mean(dim=0)
            .to(device)
        )
    else:
        x = torch.stack(in_S2, dim=1)
        dates = torch.tensor(in_S2_td).float().to(device)

    return x, y, in_m, dates
|
|
|
|
|
def log_aleatoric(writer, config, mode, step, var, name, img_meter=None):
    """Log summaries of a predicted aleatoric variance to TensorBoard.

    Args:
        writer: TensorBoard ``SummaryWriter``.
        config: configuration; ``config.loss`` must be "GNLL" or "MGNLL".
        mode: tag prefix, e.g. "train", "val" or "test".
        step: global step for all entries.
        var: predicted variance tensor. With more than 5 dims it is treated as
            a covariance whose dims 2 and 3 are the channel pair, and its
            diagonal is used. Otherwise dim 2 is averaged over as the channel
            dim (presumably ``[B, T, C, H, W]`` — TODO confirm).
        name: tag infix placed before "aleatoric".
        img_meter: optional metric meter. If given, its "UCE SE" and "AUCE SE"
            entries are logged and must already be present.

    Raises:
        NotImplementedError: if ``config.loss`` is not a likelihood loss.
    """
    if len(var.shape) > 5:
        covar = var
        # keep the per-channel variances (diagonal), moved back to dim 2
        var = var.diagonal(dim1=2, dim2=3).moveaxis(-1, 2)

        # covariance matrix averaged over the two trailing (spatial) dims
        patch_covmat = covar.mean(dim=-1).mean(dim=-1).squeeze(dim=1)
        for bdx, img in enumerate(patch_covmat):
            # .cpu(): the covariance may still live on the GPU (e.g. at val time)
            img = img.detach().cpu().numpy()

            # colour scale symmetric around 0, relative to this matrix ...
            max_abs = max(abs(img.min()), abs(img.max()))
            scale_rel_left, scale_rel_right = -max_abs, +max_abs
            fig = continuous_matshow(img, min=scale_rel_left, max=scale_rel_right)
            writer.add_figure(f"Img/{mode}/patch covmat relative {bdx}", fig, step)
            # ... and with a fixed absolute range
            scale_center0_absolute = 1 / 4 * 1**2
            fig = continuous_matshow(
                img, min=-scale_center0_absolute, max=scale_center0_absolute
            )
            writer.add_figure(f"Img/{mode}/patch covmat absolute {bdx}", fig, step)

    # average variance across channels
    avg_var = torch.mean(var, dim=2, keepdim=True)
    q50 = (
        avg_var[:, 0, ...].view(avg_var.shape[0], -1).median(dim=-1)[0].detach().clone()
    )
    q75 = (
        avg_var[:, 0, ...]
        .view(avg_var.shape[0], -1)
        .quantile(0.75, dim=-1)
        .detach()
        .clone()
    )
    # statistics of the first batch item only
    q50, q75 = q50[0], q75[0]
    binning = 256

    if config.loss in ["GNLL", "MGNLL"]:
        writer.add_image(
            f"Img/{mode}/{name}aleatoric [0,1]",
            avg_var[0, 0, ...].clip(0, 1),
            step,
            dataformats="CHW",
        )
        # avoid 0 / 0 = NaN when the 75th percentile is zero
        q75_norm = q75 if q75 > 0 else torch.ones_like(q75)
        writer.add_image(
            f"Img/{mode}/{name}aleatoric [0,q75]",
            avg_var[0, 0, ...].clip(0.0, q75) / q75_norm,
            step,
            dataformats="CHW",
        )
        writer.add_histogram(
            f"Hist/{mode}/{name}aleatoric",
            avg_var[0, 0, ...].flatten().clip(0, 1),
            step,
            bins=binning,
            max_bins=binning,
        )
    else:
        raise NotImplementedError

    writer.add_scalar(f"{mode}/{name}aleatoric median all", q50, step)
    writer.add_scalar(f"{mode}/{name}aleatoric q75 all", q75, step)
    if img_meter is not None:
        writer.add_scalar(f"{mode}/{name}UCE SE", img_meter.value()["UCE SE"], step)
        writer.add_scalar(f"{mode}/{name}AUCE SE", img_meter.value()["AUCE SE"], step)
|
|
|
|
|
def log_train(writer, config, model, step, x, out, y, in_m, name="", var=None):
    """Log training loss, example images and mask statistics to TensorBoard.

    Args:
        writer: TensorBoard ``SummaryWriter``.
        config: configuration (reads ``loss``, ``use_sar``, ``input_t``).
        model: model wrapper whose ``loss_G`` is logged.
        step: global step for all entries.
        x, out, y: input, prediction and target; the first batch item is logged
            (indexing assumes ``[B, T, C, H, W]`` — TODO confirm).
        in_m: masks, indexed as ``[B, T, H, W]`` below.
        name: optional sub-model name; if non-empty, tags get a
            ``model_<name>/`` prefix.
        var: optional predicted variance, forwarded to ``log_aleatoric``.
    """
    loss = model.loss_G.cpu()
    if name != "":
        name = f"model_{name}/"

    writer.add_scalar(f"train/{name}{config.loss}", loss, step)
    writer.add_scalar(f"train/{name}total", loss, step)

    # RGB previews use S2 bands [3, 2, 1]. With SAR, the S1 channels are
    # concatenated in front of S2, which shifts the S2 RGB bands to [5, 4, 3].
    if config.use_sar:
        writer.add_image(
            f"Img/train/{name}in_s1", x[0, :, [0], ...], step, dataformats="NCHW"
        )
        writer.add_image(
            f"Img/train/{name}in_s2", x[0, :, [5, 4, 3], ...], step, dataformats="NCHW"
        )
    else:
        writer.add_image(
            f"Img/train/{name}in_s2", x[0, :, [3, 2, 1], ...], step, dataformats="NCHW"
        )
    writer.add_image(
        f"Img/train/{name}out", out[0, 0, [3, 2, 1], ...], step, dataformats="CHW"
    )
    writer.add_image(
        f"Img/train/{name}y", y[0, 0, [3, 2, 1], ...], step, dataformats="CHW"
    )
    writer.add_image(
        f"Img/train/{name}m", in_m[0, :, None, ...], step, dataformats="NCHW"
    )

    # logical AND / OR of the masks across time (dim 1), for binary masks
    and_m = torch.prod(in_m, dim=1, keepdim=True)
    or_m = torch.sum(in_m, dim=1, keepdim=True).clip(0, 1)
    writer.add_scalar(f"train/{name}OR m %", or_m.float().mean(), step)
    writer.add_scalar(f"train/{name}AND m %", and_m.float().mean(), step)
    writer.add_image(f"Img/train/{name}AND m", and_m, step, dataformats="NCHW")
    writer.add_image(f"Img/train/{name}OR m", or_m, step, dataformats="NCHW")

    # per-pixel mean of the masks over time, one figure per batch item
    avg_m = in_m.float().mean(axis=1).cpu()
    for bdx, img in enumerate(avg_m):
        fig = discrete_matshow(img, n_colors=config.input_t)
        # same `name` prefix as all other tags (unchanged for the default name="")
        writer.add_figure(f"Img/train/{name}temp overlay m {bdx}", fig, step)

    if var is not None:
        log_aleatoric(writer, config, "train", step, var, name, img_meter=None)
|
|
|
|
|
def discrete_matshow(data, n_colors=5, min=0, max=1):
    """Draw ``data`` as a matrix with a gray colormap of ``n_colors + 1`` levels.

    ``min``/``max`` set the colour range. Axes are hidden and the layout is
    tightened. Returns the new Figure; the caller is responsible for closing it.
    """
    levels = n_colors + 1
    figure, axes = plt.subplots()
    axes.matshow(data, cmap=plt.get_cmap("gray", levels), vmin=min, vmax=max)
    axes.axis("off")
    figure.tight_layout()
    return figure
|
|
|
|
|
def continuous_matshow(data, min=0, max=1):
    """Draw ``data`` as a matrix with the diverging "seismic" colormap.

    ``min``/``max`` set the colour range; axes are hidden. Returns the new
    Figure; the caller is responsible for closing it.
    """
    figure, axes = plt.subplots()
    axes.matshow(data, cmap=plt.get_cmap("seismic"), vmin=min, vmax=max)
    axes.axis("off")
    return figure
|
|
|
|
|
def iterate(model, data_loader, config, writer, mode="train", epoch=None, device=None):
    """Run one pass over ``data_loader`` for training or evaluation.

    In "train" mode every batch goes through ``model.optimize_parameters()``.
    Otherwise ("val"/"test") predictions are computed under ``torch.no_grad()``
    and per-sample image metrics are accumulated. For the likelihood losses
    GNLL/MGNLL, calibration statistics are computed at the end of the pass.

    Args:
        model: model wrapper (``set_input``, ``forward``, ``optimize_parameters``,
            ``get_loss_G``, ``rescale``, ``fake_B``, ``loss_G``, ``netG``, ...).
        data_loader: non-empty data loader.
        config: experiment configuration.
        writer: TensorBoard ``SummaryWriter``.
        mode: "train", "val" or "test".
        epoch: epoch number, used for global steps ``(epoch - 1) * len + i`` and
            for output folder names.
        device: torch device the batches are moved to.

    Returns:
        "val"/"test": ``(metrics, img_meter.value())``; otherwise ``metrics``,
        a dict with f"{mode}_epoch_time" and f"{mode}_loss".

    Raises:
        ValueError: if ``data_loader`` is empty.
        NotImplementedError: for an unsupported ``config.sample_type``.
    """
    if len(data_loader) == 0:
        raise ValueError("Received data loader with zero samples!")

    loss_meter = tnt.meter.AverageValueMeter()
    img_meter = avg_img_metrics()

    # per-sample statistics for the calibration analysis (likelihood losses only)
    errs, errs_se, errs_ae, vars_aleatoric = [], [], [], []

    # number of samples seen so far; gives unique running indices for file
    # ids, even when the last batch is smaller than the others
    n_seen = 0

    t_start = time.time()
    for i, batch in enumerate(tqdm(data_loader)):
        step = (epoch - 1) * len(data_loader) + i

        if config.sample_type == "cloudy_cloudfree":
            x, y, in_m, dates = prepare_data(batch, device, config)
        elif config.sample_type == "pretrain":
            x, y, in_m = prepare_data(batch, device, config)
            dates = None
        else:
            raise NotImplementedError
        inputs = {"A": x, "B": y, "dates": dates, "masks": in_m}

        batch_size = y.size()[0]
        batch_offset = n_seen
        n_seen += batch_size

        if mode != "train":
            with torch.no_grad():
                model.set_input(inputs)
                model.forward()
                model.get_loss_G()
                model.rescale()
                out = model.fake_B
                # The variance is either stored on the generator or appended as
                # extra output channels after the S2 bands.
                if hasattr(model.netG, "variance") and model.netG.variance is not None:
                    var = model.netG.variance
                    model.netG.variance = None
                else:
                    var = out[:, :, S2_BANDS:, ...]
                out = out[:, :, :S2_BANDS, ...]
                if var.numel() == 0:
                    # no variance predicted: skip variance plots and exports
                    var = None
                # full covariance of this batch, if the model predicts one
                covar = None

                for bdx in range(batch_size):
                    if config.loss in ["GNLL", "MGNLL"]:
                        # A 6-D variance is a covariance tensor: keep it for export
                        # and reduce it to its diagonal (per-channel variances).
                        if len(var.shape) > 5:
                            covar = var
                            var = var.diagonal(dim1=2, dim2=3).moveaxis(-1, 2)

                        extended_metrics = img_metrics(y[bdx], out[bdx], var=var[bdx])
                        vars_aleatoric.append(extended_metrics["mean var"])
                        errs.append(extended_metrics["error"])
                        errs_se.append(extended_metrics["mean se"])
                        errs_ae.append(extended_metrics["mean ae"])
                    else:
                        extended_metrics = img_metrics(y[bdx], out[bdx])

                    img_meter.add(extended_metrics)
                    idx = batch_offset + bdx
                    if config.plot_every > 0 and idx % config.plot_every == 0:
                        plot_dir = os.path.join(
                            config.res_dir,
                            config.experiment_name,
                            "plots",
                            f"epoch_{epoch}",
                            f"{mode}",
                        )
                        plot_img(x[bdx], "in", plot_dir, file_id=idx)
                        plot_img(out[bdx], "pred", plot_dir, file_id=idx)
                        plot_img(y[bdx], "target", plot_dir, file_id=idx)
                        plot_img(
                            ((out[bdx] - y[bdx]) ** 2).mean(1, keepdims=True),
                            "err",
                            plot_dir,
                            file_id=idx,
                        )
                        plot_img(
                            discrete_matshow(
                                in_m.float().mean(axis=1).cpu()[bdx],
                                n_colors=config.input_t,
                            ),
                            "mask",
                            plot_dir,
                            file_id=idx,
                        )
                        if var is not None:
                            plot_img(
                                var.mean(2, keepdims=True)[bdx],
                                "var",
                                plot_dir,
                                file_id=idx,
                            )
                    if config.export_every > 0 and idx % config.export_every == 0:
                        export_dir = os.path.join(
                            config.res_dir,
                            config.experiment_name,
                            "export",
                            f"epoch_{epoch}",
                            f"{mode}",
                        )
                        export(out[bdx], "pred", export_dir, file_id=idx)
                        export(y[bdx], "target", export_dir, file_id=idx)
                        if var is not None:
                            if covar is not None:
                                export(covar[bdx], "covar", export_dir, file_id=idx)
                            else:
                                export(var[bdx], "var", export_dir, file_id=idx)
        else:
            model.set_input(inputs)
            model.optimize_parameters()
            out = model.fake_B.detach().cpu()

            if hasattr(model.netG, "variance") and model.netG.variance is not None:
                var = model.netG.variance.cpu()
            else:
                var = out[:, :, S2_BANDS:, ...]
            out = out[:, :, :S2_BANDS, ...]
            if var.numel() == 0:
                var = None

            if config.plot_every > 0:
                plot_out = out.detach().clone()
                for bdx in range(batch_size):
                    idx = batch_offset + bdx
                    if idx % config.plot_every == 0:
                        plot_dir = os.path.join(
                            config.res_dir,
                            config.experiment_name,
                            "plots",
                            f"epoch_{epoch}",
                            f"{mode}",
                        )
                        # Per-sample file ids (as in val/test), so several samples
                        # plotted from one batch do not overwrite each other.
                        plot_img(x[bdx], "in", plot_dir, file_id=idx)
                        plot_img(plot_out[bdx], "pred", plot_dir, file_id=idx)
                        plot_img(y[bdx], "target", plot_dir, file_id=idx)

        if mode == "train":
            if step % config.display_step == 0:
                out, x, y, in_m = out.cpu(), x.cpu(), y.cpu(), in_m.cpu()
                if config.loss in ["GNLL", "MGNLL"] and var is not None:
                    log_train(writer, config, model, step, x, out, y, in_m, var=var.cpu())
                else:
                    log_train(writer, config, model, step, x, out, y, in_m)

        loss_meter.add(model.loss_G.item())

        # close the figures created while plotting/logging this batch
        plt.close("all")

    t_end = time.time()
    total_time = t_end - t_start
    print("Epoch time : {:.1f}s".format(total_time))
    metrics = {f"{mode}_epoch_time": total_time}
    metrics[f"{mode}_loss"] = loss_meter.value()[0]

    if mode == "train":
        current_lr = model.optimizer_G.state_dict()["param_groups"][0]["lr"]
        writer.add_scalar("Etc/train/lr", current_lr, step)
        model.scheduler_G.step()

    if mode == "test" or mode == "val":
        for key, val in img_meter.value().items():
            writer.add_scalar(f"{mode}/{key}", val, step)
        writer.add_scalar(f"{mode}/loss", metrics[f"{mode}_loss"], step)

        # example images from the last batch
        if config.use_sar:
            writer.add_image(
                f"Img/{mode}/in_s1", x[0, :, [0], ...], step, dataformats="NCHW"
            )
            writer.add_image(
                f"Img/{mode}/in_s2", x[0, :, [5, 4, 3], ...], step, dataformats="NCHW"
            )
        else:
            writer.add_image(
                f"Img/{mode}/in_s2", x[0, :, [3, 2, 1], ...], step, dataformats="NCHW"
            )
        writer.add_image(
            f"Img/{mode}/out", out[0, 0, [3, 2, 1], ...], step, dataformats="CHW"
        )
        writer.add_image(
            f"Img/{mode}/y", y[0, 0, [3, 2, 1], ...], step, dataformats="CHW"
        )
        writer.add_image(
            f"Img/{mode}/m", in_m[0, :, None, ...], step, dataformats="NCHW"
        )

        if config.loss in ["GNLL", "MGNLL"]:
            # mean squared error as increasingly uncertain samples are included
            sorted_errors_se = compute_ece(
                vars_aleatoric, errs_se, len(data_loader.dataset), percent=5
            )
            sorted_errors = {"se_sortAleatoric": sorted_errors_se}
            plot_discard(
                sorted_errors["se_sortAleatoric"], config, mode, step, is_se=True
            )

            uce_l2, auce_l2 = compute_uce_auce(
                vars_aleatoric,
                errs,
                len(data_loader.dataset),
                percent=5,
                l2=True,
                mode=mode,
                step=step,
            )
            # Stored in the meter's value dict so that they are returned and
            # logged with the other metrics (assumes value() returns the meter's
            # live dict — TODO confirm).
            img_meter.value()["UCE SE"] = uce_l2.cpu().numpy().item()
            img_meter.value()["AUCE SE"] = auce_l2.cpu().numpy().item()

            if var is not None:
                log_aleatoric(writer, config, mode, step, var, "model/", img_meter)

        return metrics, img_meter.value()
    else:
        return metrics
|
|
|
|
|
def plot_discard(sorted_errors, config, mode, step, is_se=True):
    """Plot cumulative errors ordered by uncertainty, plus a linear fit.

    Logs the figure to the module-level TensorBoard ``writer``. In "test" mode
    it is also saved as PNG and PDF to ``<res_dir>/<experiment_name>``.

    Args:
        sorted_errors: 1-D sequence with one mean error per cumulative
            uncertainty bin (e.g. the output of ``compute_ece``). It is copied,
            never modified in place.
        config: configuration (reads ``res_dir``, ``experiment_name``).
        mode: "train", "val" or "test".
        step: global step.
        is_se: label the errors as squared ("SE") or absolute ("AE") errors.
    """
    metric = "SE" if is_se else "AE"

    # copy so that NaN filling below does not modify the caller's array
    sorted_errors = np.array(sorted_errors, dtype=float)
    n_points = len(sorted_errors)
    # one x-position per bin, spaced by the bin width (0.05 for 20 bins)
    x_axis = np.arange(n_points) / n_points

    fig, ax = plt.subplots()
    ax.scatter(
        x_axis,
        sorted_errors,
        c="b",
        alpha=1.0,
        marker=r".",
        label=f"{metric}, sorted by uncertainty",
    )

    # empty bins yield NaN; replace them by the mean so the fit is defined
    sorted_errors[np.isnan(sorted_errors)] = np.nanmean(sorted_errors)
    # polyfit returns [slope, intercept] for deg=1
    b, a = np.polyfit(x_axis, sorted_errors, deg=1)
    x_seq = np.linspace(0, 1.0, num=1000)
    ax.plot(
        x_seq,
        a + b * x_seq,
        c="k",
        lw=1.5,
        alpha=0.75,
        label=f"linear fit, {round(a, 3)} + {round(b, 3)} * x",
    )
    ax.set_xlabel("Fraction of samples, sorted ascendingly by uncertainty")
    ax.set_ylabel("Error")
    ax.legend(loc="upper left")
    ax.grid()
    fig.tight_layout()
    writer.add_figure(f"Img/{mode}/discard_uncertain", fig, step)
    if mode == "test":
        path_to = os.path.join(config.res_dir, config.experiment_name)
        print(f"Logging discard plots to path {path_to}")
        fig.savefig(
            os.path.join(path_to, f"plot_{mode}_{metric}_discard.png"),
            bbox_inches="tight",
            dpi=int(1e3),
        )
        fig.savefig(
            os.path.join(path_to, f"plot_{mode}_{metric}_discard.pdf"),
            bbox_inches="tight",
            dpi=int(1e3),
        )
|
|
|
|
|
def compute_ece(vars, errors, n_samples, percent=5):
    """Return the cumulative mean error as samples are admitted by increasing uncertainty.

    Samples are ordered by ``vars``. For each of the ``100 // percent``
    cut-offs, evenly spaced (as integers) over ``[0, n_samples]`` with the 0
    excluded, the NaN-ignoring mean of the errors of all samples up to that
    cut-off is computed.

    Returns:
        numpy array with one entry per cut-off.
    """
    order = torch.sort(torch.Tensor(vars))[1]
    errors_by_uncertainty = torch.Tensor(errors)[order]

    n_bins = 100 // percent
    cutoffs = torch.linspace(0, n_samples, n_bins + 1, dtype=int)[1:]

    cumulative = []
    for cut in cutoffs:
        cumulative.append(torch.nanmean(errors_by_uncertainty[:cut]).cpu().numpy())
    return np.array(cumulative)
|
|
|
|
|
def binarize(arg, n_bins, floor=0, ceil=1):
    """Map values to integer bin indices ``0 .. n_bins - 1`` over ``[floor, ceil]``.

    ``n_bins`` evenly spaced edges split the interval into ``n_bins - 1`` steps.
    ``np.digitize`` is applied against all edges but the first, so values below
    the second edge map to 0 and values ``>= ceil`` map to ``n_bins - 1``.
    """
    return np.digitize(arg, bins=np.linspace(floor, ceil, num=n_bins)[1:])
|
|
|
|
|
def compute_uce_auce(var, errors, n_samples, percent=5, l2=True, mode="val", step=0):
    """Compute the uncertainty calibration error (UCE) and its unweighted mean (AUCE).

    Samples are put into ``100 // percent`` equal-width bins of predicted
    variance spanning ``[min(var), max(var)]``. In each bin, the aggregated
    predicted standard deviation (``sqrt(var)``) is compared with the
    aggregated error. Aggregation is RMS for ``l2=True`` and mean absolute
    value otherwise.

    * UCE = sum over bins of (bin count / ``n_samples``) * |error - std|
    * AUCE = mean over non-empty bins of |error - std|

    A reliability diagram is logged to the module-level TensorBoard ``writer``.

    Returns:
        ``(uce, auce)`` as 0-d tensors.
    """
    n_bins = 100 // percent
    var, errors = torch.Tensor(var), torch.Tensor(errors)

    def metric(arg):
        return torch.sqrt(torch.mean(arg**2)) if l2 else torch.mean(torch.abs(arg))

    m_str = "L2" if l2 else "L1"

    # bin index 0 .. n_bins - 1 of every sample
    var_idx = torch.Tensor(binarize(var, n_bins, floor=var.min(), ceil=var.max()))

    # per-bin aggregated std and error (NaN for empty bins)
    bk_var, bk_err = torch.empty(n_bins), torch.empty(n_bins)
    for bin_idx in range(n_bins):
        in_bin = var_idx == bin_idx
        bk_var[bin_idx] = metric(var[in_bin].sqrt())
        bk_err[bin_idx] = metric(errors[in_bin])

    calib_err = torch.abs(bk_err - bk_var)
    # Exact per-bin sample counts. torch.histogram would derive its own bin
    # edges from the observed range of var_idx, which does not line up with the
    # integer bins above when that range is degenerate (e.g. all variances
    # equal, so every sample sits in the last bin).
    counts = torch.bincount(var_idx.long(), minlength=n_bins)[:n_bins]
    bk_weight = counts / n_samples
    uce = torch.nansum(bk_weight * calib_err)
    auce = torch.nanmean(calib_err)

    # reliability diagram
    fig, ax = plt.subplots()
    x_min, x_max = bk_var[~bk_var.isnan()].min(), bk_var[~bk_var.isnan()].max()
    y_min, y_max = 0, bk_err[~bk_err.isnan()].max()
    x_axis = np.linspace(x_min, x_max, num=n_bins)

    ax.plot(x_axis, x_axis, label="identity")
    ax.bar(
        x_axis,
        bk_err,
        width=x_axis[1] - x_axis[0],
        alpha=0.75,
        edgecolor="k",
        color="gray",
        label=f"{m_str} error per bin",
    )

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xlabel("Uncertainty")
    ax.set_ylabel(f"{m_str} Error")
    ax.legend(loc="upper left")
    ax.grid()
    fig.tight_layout()
    writer.add_figure(f"Img/{mode}/err_vs_var_{m_str}", fig, step)

    return uce, auce
|
|
|
|
|
def recursive_todevice(x, device):
    """Move every tensor nested in ``x`` to ``device``.

    Tensors are moved with ``.to(device)``. Dicts keep their keys. Other
    iterables (lists, tuples, ...) are returned as lists. Strings, bytes and
    non-iterable leaves (numbers, None, ...) are returned unchanged instead of
    being recursed into.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {k: recursive_todevice(v, device) for k, v in x.items()}
    # strings would otherwise recurse forever (each character is a string)
    if isinstance(x, (str, bytes)):
        return x
    try:
        items = iter(x)
    except TypeError:
        return x
    return [recursive_todevice(c, device) for c in items]
|
|
|
|
|
def prepare_output(config):
    """Create ``<res_dir>/<experiment_name>``; an existing directory is fine."""
    experiment_dir = os.path.join(config.res_dir, config.experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)
|
|
|
|
|
def checkpoint(log, config):
    """Overwrite ``<res_dir>/<experiment_name>/trainlog.json`` with ``log`` (indented JSON)."""
    log_path = os.path.join(config.res_dir, config.experiment_name, "trainlog.json")
    with open(log_path, "w") as handle:
        json.dump(log, handle, indent=4)
|
|
|
|
|
def save_results(metrics, path, split="test"):
    """Write ``metrics`` as indented JSON to ``<path>/<split>_metrics.json``."""
    target_file = os.path.join(path, f"{split}_metrics.json")
    with open(target_file, "w") as handle:
        json.dump(metrics, handle, indent=4)
|
|
|
|
|
|
|
def import_from_path(split, config):
    """Locate a precomputed s2cloudless mask file for ``split``.

    Looks in ``<parent of cwd>/util/precomputed`` if that directory exists,
    otherwise in ``config.precomputed``. The file name is built from
    ``config.input_t``, ``split`` and ``config.region``.

    Returns:
        The file path if the file exists, else ``None``.
    """
    file_name = f"generic_{config.input_t}_{split}_{config.region}_s2cloudless_mask.npy"
    local_dir = os.path.join(os.path.dirname(os.getcwd()), "util", "precomputed")
    base_dir = local_dir if os.path.exists(local_dir) else config.precomputed
    candidate = os.path.join(base_dir, file_name)
    return candidate if os.path.isfile(candidate) else None
|
|
|
|
|
def _subsample(dataset, config):
    """Keep the first n items of ``dataset``, with n capped by ``config.max_samples_count`` and ``config.max_samples_frac`` of its own length."""
    n_keep = min(
        config.max_samples_count,
        len(dataset),
        int(len(dataset) * config.max_samples_frac),
    )
    return torch.utils.data.Subset(dataset, range(0, n_keep))


def _make_loader(dataset, config, shuffle, generator):
    """Build a seeded DataLoader with the settings shared by the train/val/test loaders."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        generator=generator,
        num_workers=config.num_workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0
        persistent_workers=config.num_workers > 0,
    )


def main(config):
    """Train, validate each ``val_every`` epochs, then test the best checkpoint.

    The model with the best validation loss is saved as "model", and every
    epoch is additionally saved as "model_epoch_<n>". The final test run loads
    "model" and writes its metrics to ``<res_dir>/<experiment_name>/test_metrics.json``.

    Raises:
        NotImplementedError: if ``config.pretrain`` is False; only the
            single-time-point SEN12MSCR data is set up in this script.
    """
    prepare_output(config)
    device = torch.device(config.device)

    if not config.pretrain:
        raise NotImplementedError(
            "Only training on mono-temporal SEN12MSCR data (config.pretrain=True) "
            "is supported by this script."
        )

    def load_split(split):
        return SEN12MSCR(
            os.path.expanduser(config.root3),
            split=split,
            region=config.region,
            sample_type=config.sample_type,
        )

    # NOTE: validation uses the "test" split and testing the "val" split
    # (validation results are accordingly saved as "test_epoch_<n>" below).
    dt_train = load_split("train")
    dt_val = load_split("test")
    dt_test = load_split("val")

    # Optional subsampling (e.g. for quick runs). Each cap is relative to that
    # split's own size; previously val/test were capped by the train length.
    dt_train = _subsample(dt_train, config)
    dt_val = _subsample(dt_val, config)
    dt_test = _subsample(dt_test, config)

    train_loader = _make_loader(dt_train, config, shuffle=True, generator=f)
    val_loader = _make_loader(dt_val, config, shuffle=False, generator=g)
    test_loader = _make_loader(dt_test, config, shuffle=False, generator=g)

    print("Train {}, Val {}, Test {}".format(len(dt_train), len(dt_val), len(dt_test)))

    model = get_model(config)
    model.len_epoch = len(train_loader)

    config.N_params = utils.get_ntrainparams(model)
    print("\n\nTrainable layers:")
    for name, p in model.named_parameters():
        if p.requires_grad:
            print(f"\t{name}")
    model = model.to(device)

    print("\nInitializing weights randomly.")
    model.netG.apply(weight_init)

    if config.trained_checkp:
        print(f"Loading weights from (pre-)trained checkpoint {config.trained_checkp}")
        load_model(
            config,
            model,
            train_out_layer=False,
            load_out_partly=config.model in ["uncrtaints"],
        )

    with open(
        os.path.join(config.res_dir, config.experiment_name, "conf.json"), "w"
    ) as file:
        file.write(json.dumps(vars(config), indent=4))
    print(f"TOTAL TRAINABLE PARAMETERS: {config.N_params}\n")
    print(model)

    model.criterion = losses.get_loss(config)

    # checkpoint whenever the validation loss does not get worse
    is_better, best_loss = lambda new, prev: new <= prev, float("inf")

    trainlog = {}

    # continue after the requested epoch, or after the scheduler's last epoch
    begin_at = (
        config.resume_at
        if config.resume_at >= 0
        else model.scheduler_G.state_dict()["last_epoch"]
    )
    # keeps `epoch` defined for the final test run even if no epoch is left to train
    epoch = begin_at
    for epoch in range(begin_at + 1, config.epochs + 1):
        print("\nEPOCH {}/{}".format(epoch, config.epochs))

        model.train()
        model.netG.train()

        if epoch > config.unfreeze_after and hasattr(model, "frozen") and model.frozen:
            print("Unfreezing all network layers")
            model.frozen = False
            freeze_layers(model.netG, grad=True)

        # re-seed the training generator every epoch to draw different samples
        if config.vary_samples:
            f.manual_seed(config.rdm_seed + epoch)
            train_loader = torch.utils.data.DataLoader(
                dt_train,
                batch_size=config.batch_size,
                shuffle=True,
                worker_init_fn=seed_worker,
                generator=f,
                num_workers=config.num_workers,
            )

        train_metrics = iterate(
            model,
            data_loader=train_loader,
            config=config,
            writer=writer,
            mode="train",
            epoch=epoch,
            device=device,
        )

        if epoch % config.val_every == 0 and epoch > config.val_after:
            print("Validation . . . ")

            model.eval()
            model.netG.eval()

            val_metrics, val_img_metrics = iterate(
                model,
                data_loader=val_loader,
                config=config,
                writer=writer,
                mode="val",
                epoch=epoch,
                device=device,
            )

            print("Using training loss as validation loss")
            if "val_loss" in val_metrics:
                val_loss = val_metrics["val_loss"]
            else:
                val_loss = val_metrics["val_loss_ensembleAverage"]

            print(f"Validation Loss {val_loss}")
            print(f"validation image metrics: {val_img_metrics}")
            save_results(
                val_img_metrics,
                os.path.join(config.res_dir, config.experiment_name),
                split=f"test_epoch_{epoch}",
            )
            print(
                f"\nLogged validation epoch {epoch} metrics to path {os.path.join(config.res_dir, config.experiment_name)}"
            )

            trainlog[epoch] = {**train_metrics, **val_metrics}
            checkpoint(trainlog, config)
            if is_better(val_loss, best_loss):
                best_loss = val_loss
                save_model(config, epoch, model, "model")
        else:
            trainlog[epoch] = {**train_metrics}
            checkpoint(trainlog, config)

        # always keep the current epoch's weights
        save_model(config, epoch, model, f"model_epoch_{epoch}")

        print(f"Completed current epoch of experiment {config.experiment_name}.")

    print("Testing best epoch . . .")
    load_checkpoint(config, config.res_dir, model, "model")

    model.eval()
    model.netG.eval()

    test_metrics, test_img_metrics = iterate(
        model,
        data_loader=test_loader,
        config=config,
        writer=writer,
        mode="test",
        epoch=epoch,
        device=device,
    )

    if "test_loss" in test_metrics:
        test_loss = test_metrics["test_loss"]
    else:
        test_loss = test_metrics["test_loss_ensembleAverage"]
    print(f"Test Loss {test_loss}")
    print(f"\nTest image metrics: {test_img_metrics}")
    save_results(
        test_img_metrics,
        os.path.join(config.res_dir, config.experiment_name),
        split="test",
    )
    print(
        f"\nLogged test metrics to path {os.path.join(config.res_dir, config.experiment_name)}"
    )

    writer.close()

    print(f"Finished training experiment {config.experiment_name}.")
|
|
|
|
|
if __name__ == "__main__":
    main(config)
    # sys.exit instead of the site-provided exit(), which is meant for the
    # interactive interpreter and is missing when Python runs with -S
    sys.exit()
|
|