|
""" |
|
Python script to obtain Deep Ensemble predictions by combining the pre-computed predictions of every ensemble member.

Each member's predictions must first be pre-computed with test_reconstruct.py and exported to disk, from where this script reads them back in.
Online ensembling is currently not implemented, as it may exceed hardware constraints.

For every ensemble member, the path to its export directory has to be specified in the list 'ensemble_paths'.
|
""" |
|
|
|
import os |
|
import sys |
|
import torch |
|
import numpy as np |
|
from tqdm import tqdm |
|
from natsort import natsorted |
|
|
|
# Absolute directory containing this script. Its parent directory is appended
# to the import path, presumably so that the project-level packages imported
# next (data..., src...) can be resolved when running the script directly.
dirname = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.split(dirname)[0])
|
|
|
from data.dataLoader import SEN12MSCR, SEN12MSCRTS |
|
from src.learning.metrics import img_metrics, avg_img_metrics |
|
from train_reconstruct import recursive_todevice, compute_uce_auce, export, plot_img, save_results |
|
|
|
# ---------------------------------------------------------------------------
# Evaluation settings
# ---------------------------------------------------------------------------
epoch = 1              # epoch whose exported predictions are ensembled
root = '/home/data/'   # dataset root; SEN12MSCRTS is read from below it
mode = 'test'          # dataset split to evaluate
in_time = 3            # passed as n_input_samples to the dataset
region = 'all'         # passed as region to the dataset
max_samples = 1e9      # stop after this many samples
uncertainty = 'both'   # 'aleatoric', 'epistemic' or 'both'
ensemble = True        # False: evaluate only the first member's outputs
pixelwise = True       # collect per-pixel instead of per-image statistics
export_path = None     # when set, plots, exports and results are written here

# Export directories of the ensemble members diagonal_1 ... diagonal_5.
ensemble_paths = [
    os.path.join(dirname, 'inference', f'diagonal_{member}/export/epoch_{epoch}/{mode}')
    for member in range(1, 6)
]
n_ensemble = len(ensemble_paths)

print('Ensembling over model predictions:')
for instance in ensemble_paths:
    print(instance)

# Output directories are only defined when exporting is enabled.
if export_path:
    plot_dir = os.path.join(export_path, 'plots', f'epoch_{epoch}', f'{mode}')
    export_dir = os.path.join(export_path, 'export', f'epoch_{epoch}', f'{mode}')
|
|
|
|
|
def prepare_data_multi(batch, device, batch_size=1, use_sar=True):
    """Turn one (un-batched) dataset sample into model-ready tensors.

    Every entry read from ``batch`` is converted with ``torch.tensor`` and
    moved to ``device`` via ``recursive_todevice``.

    Args:
        batch: sample dict providing ``batch['input']['S2']``,
            ``['S2 TD']``, ``['masks']`` (plus ``['S1']`` and ``['S1 TD']``
            when ``use_sar``) and ``batch['target']['S2']``.
        device: target device, e.g. 'cuda'.
        batch_size: when > 1, the acquisition-date tensors are transposed.
        use_sar: if True, concatenate the S1 inputs with the S2 inputs and
            average the S1/S2 acquisition dates.

    Returns:
        Tuple ``(x, y, in_m, dates)``: inputs, target and masks each with a
        leading singleton axis added, and the acquisition dates as float.
    """
    in_S2 = recursive_todevice(torch.tensor(batch['input']['S2']), device)
    in_S2_td = recursive_todevice(torch.tensor(batch['input']['S2 TD']), device)
    if batch_size > 1:
        # in_S2_td is already a tensor: torch.stack() needs a *sequence* of
        # tensors and raised TypeError here. Stacking the rows of a tensor and
        # transposing is simply a transpose.
        in_S2_td = in_S2_td.T
    in_m = recursive_todevice(torch.tensor(batch['input']['masks']), device)
    y = recursive_todevice(torch.tensor(batch['target']['S2']), device)

    if use_sar:
        in_S1 = recursive_todevice(torch.tensor(batch['input']['S1']), device)
        in_S1_td = recursive_todevice(torch.tensor(batch['input']['S1 TD']), device)
        if batch_size > 1:
            in_S1_td = in_S1_td.T
        # torch.stack(in_S1, dim=1) raised TypeError because in_S1 is a tensor,
        # not a list of per-time tensors.
        # NOTE(review): assumes in_S1/in_S2 are (T, C, H, W), so channels are
        # dim 1 -- consistent with the non-SAR branch returning in_S2 as-is,
        # but TODO confirm against the dataloader.
        x = torch.cat((in_S1, in_S2), dim=1)
        dates = torch.stack((in_S1_td, in_S2_td)).float().mean(dim=0).to(device)
    else:
        x = in_S2
        # Already a tensor; no need to copy-construct it via torch.tensor().
        dates = in_S2_td.float().to(device)

    return x.unsqueeze(dim=0), y.unsqueeze(dim=0), in_m.unsqueeze(dim=0), dates
|
|
|
|
|
def _load_member_predictions(pred_name, member_paths):
    """Load one sample's exported mean and variance from every ensemble member.

    For each directory in ``member_paths`` the file ``pred_name`` is loaded as
    the mean. For the variance, the '_covar' file is preferred; the '_var' file
    is used when no '_covar' file exists.

    Args:
        pred_name: basename of the sample's prediction file (contains '_pred').
        member_paths: export directories of the ensemble members.

    Returns:
        Two lists ``(means, variances)`` holding one numpy array per member.

    Raises:
        Whatever ``np.load`` raises for a missing or unreadable file
        (e.g. FileNotFoundError, ValueError, EOFError).
    """
    covar_name = pred_name.replace('_pred', '_covar')
    var_name = pred_name.replace('_pred', '_var')
    means, variances = [], []
    for path in member_paths:
        means.append(np.load(os.path.join(path, pred_name)))
        covar_file = os.path.join(path, covar_name)
        var_file = covar_file if os.path.isfile(covar_file) else os.path.join(path, var_name)
        variances.append(np.load(var_file))
    return means, variances


def _combine_members(mean, var, uncertainty_type, use_ensemble):
    """Reduce per-member predictions (member axis 0) to one mean and one variance.

    If ``use_ensemble`` is False, the first member's mean and variance are
    returned unchanged. Otherwise the mean is averaged over members, and the
    variance depends on ``uncertainty_type``:
        'aleatoric':  mean of the member variances,
        'epistemic':  variance of the member means,
        'both':       sum of the two (law of total variance).

    The member-mean variance uses np.var instead of E[mu^2] - E[mu]^2, which
    can turn slightly negative through floating-point cancellation.

    Raises:
        NotImplementedError: for any other ``uncertainty_type`` (only checked
            when ``use_ensemble`` is True).
    """
    if not use_ensemble:
        return mean[0], var[0]
    mean_ensemble = np.mean(mean, axis=0)
    if uncertainty_type == 'aleatoric':
        var_ensemble = np.mean(var, axis=0)
    elif uncertainty_type == 'epistemic':
        var_ensemble = np.var(mean, axis=0)
    elif uncertainty_type == 'both':
        var_ensemble = np.mean(var, axis=0) + np.var(mean, axis=0)
    else:
        raise NotImplementedError(f'Unknown uncertainty type: {uncertainty_type}')
    return mean_ensemble, var_ensemble


def main():
    """Evaluate the ensemble on every exported sample of split `mode`.

    Sample files ('*_pred.npy') are enumerated in the first member's export
    directory. For each one, all members' predictions are loaded, combined
    according to `uncertainty` / `ensemble`, and scored against the
    SEN12MSCRTS target with `img_metrics`. Afterwards UCE/AUCE are computed
    from the collected variances and errors, then everything is printed and,
    if `export_path` is set, saved.

    Raises:
        AssertionError: if the dataset and the first member's export directory
            hold a different number of samples.
    """
    data_path = ensemble_paths[0]
    samples = natsorted([os.path.join(data_path, f) for f in os.listdir(data_path)
                         if os.path.isfile(os.path.join(data_path, f)) and '_pred.npy' in f])

    img_meter = avg_img_metrics()
    # Predicted variances (of the configured `uncertainty` type) and errors,
    # per pixel or per image depending on `pixelwise`; inputs to UCE/AUCE.
    pred_vars, errs = [], []

    # NOTE(review): resolved against the current working directory rather
    # than the script directory -- run from the expected location.
    import_data_path = os.path.join(os.getcwd(), 'util', 'precomputed',
                                    f'generic_{in_time}_{mode}_{region}_s2cloudless_mask.npy')
    if not os.path.isfile(import_data_path):
        import_data_path = None
    dt_test = SEN12MSCRTS(os.path.join(root, 'SEN12MSCRTS'), split=mode, region=region,
                          sample_type="cloudy_cloudfree", n_input_samples=in_time,
                          import_data_path=import_data_path)
    if len(dt_test.paths) != len(samples):
        raise AssertionError(f'{len(dt_test.paths)} dataset samples but {len(samples)} '
                             f'exported predictions in {data_path}')

    for idx, sample_path in enumerate(tqdm(samples)):
        if idx >= max_samples:
            break

        # Only missing/unreadable files count as "missing data"; anything else
        # (including KeyboardInterrupt) propagates instead of being swallowed.
        try:
            mean, var = _load_member_predictions(os.path.basename(sample_path), ensemble_paths)
        except (OSError, ValueError, EOFError):
            print(f'Skipped sample {idx}, missing data.')
            continue
        mean, var = np.array(mean), np.array(var)

        # A >4-D variance array is treated as full covariance matrices over
        # axes 1 and 2 (presumably member, C, C, H, W): keep the diagonal and
        # move that axis back to position 1.
        if var.ndim > 4:
            var = np.moveaxis(np.diagonal(var, axis1=1, axis2=2), -1, 1)

        mean_ensemble, var_ensemble = _combine_members(mean, var, uncertainty, ensemble)
        mean_ensemble = torch.tensor(mean_ensemble).cuda()
        var_ensemble = torch.tensor(var_ensemble).cuda()

        # The dataset sample is only loaded once its predictions are known to exist.
        batch = dt_test.getsample(idx)
        x, y, _, _ = prepare_data_multi(batch, 'cuda', batch_size=1, use_sar=False)

        extended_metrics = img_metrics(y[0], mean_ensemble.unsqueeze(dim=0),
                                       var=var_ensemble.unsqueeze(dim=0),
                                       pixelwise=pixelwise)
        img_meter.add(extended_metrics)

        if pixelwise:
            pred_vars.extend(extended_metrics['pixelwise var'])
            errs.extend(extended_metrics['pixelwise error'])
        else:
            pred_vars.append(extended_metrics['mean var'])
            errs.append(extended_metrics['error'])

        if export_path:
            plot_img(mean_ensemble.unsqueeze(dim=0), 'pred', plot_dir, file_id=idx)
            plot_img(x[0], 'in', plot_dir, file_id=idx)
            # Variance averaged over axis 0 and replicated to 3 channels for plotting.
            var_img = var_ensemble.mean(dim=0, keepdims=True).expand(3, *var_ensemble.shape[1:])
            plot_img(var_img.unsqueeze(dim=0), 'var', plot_dir, file_id=idx)
            export(mean_ensemble[None], 'pred', export_dir, file_id=idx)
            export(var_ensemble[None], 'var', export_dir, file_id=idx)

    uce_l2, auce_l2 = compute_uce_auce(pred_vars, errs, len(pred_vars), percent=5, l2=True, mode=mode, step=0)

    # Fetch the metrics dict once, so the UCE/AUCE entries added here are
    # guaranteed to be in what gets printed and saved.
    metrics = img_meter.value()
    metrics['UCE SE'] = uce_l2.cpu().numpy().item()
    metrics['AUCE SE'] = auce_l2.cpu().numpy().item()

    print(f'{mode} split image metrics: {metrics}')
    if export_path:
        np.save(os.path.join(export_path, f'pred_var_{uncertainty}.npy'), pred_vars)
        np.save(os.path.join(export_path, 'errors.npy'), errs)
        save_results(metrics, export_path, split=mode)
        print(f'Exported predictions to path {export_path}')
|
|
|
|
|
if __name__ == "__main__":
    main()
    # sys.exit() rather than the site-injected exit(), which is meant for the
    # interactive shell and is missing when Python runs with -S.
    sys.exit()