"""
 Python script to obtain Deep Ensemble predictions by collecting each instance's pre-computed predictions.
    Each member's predictions are first meant to be pre-computed via test_reconstruct.py, with the outputs exported,
    and read again in this script. Online ensembling is currently not implemented as this may exceed hardware constraints.
    For every ensemble member, the path to its output directory has to be specified in the list 'ensemble_paths'.
"""

import os
import sys
import torch
import numpy as np
from tqdm import tqdm
from natsort import natsorted

dirname = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(dirname))

from data.dataLoader import SEN12MSCR, SEN12MSCRTS
from src.learning.metrics import img_metrics, avg_img_metrics
from train_reconstruct import recursive_todevice, compute_uce_auce, export, plot_img, save_results

epoch       = 1             # epoch of the checkpoints whose exported predictions are ensembled
root        = '/home/data/' # path to directory containing dataset
mode        = 'test'        # split to evaluate on
in_time     = 3             # length of input time series
region      = 'all'         # region of areas of interest
max_samples = 1e9           # maximum count of samples to consider
uncertainty = 'both'        # one of 'aleatoric', 'epistemic' or 'both'; only matters if ensemble==True
ensemble    = True          # whether to compute ensemble mean and var or not
pixelwise   = True          # whether to keep pixel-based error and variance statistics (True) or summarize them per image for AUCE and UCE (False)
export_path = None          # where to export ensemble statistics, set to None if no writing to files is desired

# define the paths in which to find the individual ensemble members' predictions
ensemble_paths = [os.path.join(dirname, 'inference', f'diagonal_1/export/epoch_{epoch}/{mode}'),
                  os.path.join(dirname, 'inference', f'diagonal_2/export/epoch_{epoch}/{mode}'),
                  os.path.join(dirname, 'inference', f'diagonal_3/export/epoch_{epoch}/{mode}'),
                  os.path.join(dirname, 'inference', f'diagonal_4/export/epoch_{epoch}/{mode}'), 
                  os.path.join(dirname, 'inference', f'diagonal_5/export/epoch_{epoch}/{mode}'),
                  ]

n_ensemble = len(ensemble_paths)
print('Ensembling over model predictions:')
for instance in ensemble_paths: print(instance)

if export_path:
    plot_dir    = os.path.join(export_path, 'plots', f'epoch_{epoch}', f'{mode}')
    export_dir  = os.path.join(export_path, 'export', f'epoch_{epoch}', f'{mode}')


def prepare_data_multi(batch, device, batch_size=1, use_sar=True):
    """Move one dataloader sample to the target device and assemble the model inputs:
    the input time series x, the target y, the cloud masks and the acquisition dates."""
    in_S2     = recursive_todevice(torch.tensor(batch['input']['S2']), device)
    in_S2_td  = recursive_todevice(torch.tensor(batch['input']['S2 TD']), device)
    if batch_size > 1: in_S2_td = torch.stack(in_S2_td).T
    in_m      = recursive_todevice(torch.tensor(batch['input']['masks']), device)
    target_S2 = recursive_todevice(torch.tensor(batch['target']['S2']), device)
    y         = target_S2

    if use_sar:
        in_S1    = recursive_todevice(torch.tensor(batch['input']['S1']), device)
        in_S1_td = recursive_todevice(torch.tensor(batch['input']['S1 TD']), device)
        if batch_size > 1: in_S1_td = torch.stack(in_S1_td).T
        # concatenate the SAR and optical time series along the channel dimension
        # and average their acquisition dates
        x     = torch.cat((torch.stack(in_S1, dim=1), torch.stack(in_S2, dim=1)), dim=2)
        dates = torch.stack((in_S1_td, in_S2_td)).float().mean(dim=0).to(device)
    else:
        x     = in_S2
        dates = in_S2_td.float().to(device)

    # prepend a batch dimension of size 1
    return x.unsqueeze(dim=0), y.unsqueeze(dim=0), in_m.unsqueeze(dim=0), dates


def main():

    # list all predictions of the first ensemble member
    dataPath = ensemble_paths[0]
    samples  = natsorted([os.path.join(dataPath, f) for f in os.listdir(dataPath) if (os.path.isfile(os.path.join(dataPath, f)) and "_pred.npy" in f)])

    # collect sample-averaged uncertainties and errors
    img_meter  = avg_img_metrics()
    vars_aleatoric = []
    errs, errs_se, errs_ae = [], [], []

    import_data_path = os.path.join(os.getcwd(), 'util', 'precomputed', f'generic_{in_time}_{mode}_{region}_s2cloudless_mask.npy')
    import_data_path = import_data_path if os.path.isfile(import_data_path) else None
    dt_test          = SEN12MSCRTS(os.path.join(root, 'SEN12MSCRTS'), split=mode, region=region, sample_type="cloudy_cloudfree" , n_input_samples=in_time, import_data_path=import_data_path)
    if len(dt_test.paths) != len(samples): raise AssertionError('Mismatch between the number of dataset samples and exported predictions.')

    # iterate over the ensemble members' mean predictions
    for idx, sample_mean in enumerate(tqdm(samples)):
        if idx >= max_samples: break # exceeded desired sample count

        # fetch the target data and cloud masks of the idx-th sample in order to compute metrics
        batch = dt_test.getsample(idx)
        x, y, in_m, _ = prepare_data_multi(batch, 'cuda', batch_size=1, use_sar=False)

        try:
            mean, var = [], []
            for path in ensemble_paths: # for each ensemble member ...
                # ... load the member's mean predictions and ...
                mean.append(np.load(os.path.join(path, os.path.basename(sample_mean))))
                # ... load the member's covariance or variance predictions
                sample_var = sample_mean.replace('_pred', '_covar')
                if not os.path.isfile(os.path.join(path, os.path.basename(sample_var))):
                    sample_var = sample_mean.replace('_pred', '_var')
                var.append(np.load(os.path.join(path, os.path.basename(sample_var))))
        except OSError:
            # skip any sample for which not all members provide predictions
            # (note: the dataloader's sample for this index has already been consumed)
            print(f'Skipped sample {idx}, missing data.')
            continue
        mean, var = np.array(mean), np.array(var)

        # get the variances from the covariance matrix
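        # Shape bookkeeping, assuming members export full per-pixel covariances of
        # shape (C, C, H, W), stacked above to (n_ensemble, C, C, H, W):
        #   np.diagonal(var, axis1=1, axis2=2)  ->  (n_ensemble, H, W, C)
        #   np.moveaxis(..., -1, 1)             ->  (n_ensemble, C, H, W)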
        if len(var.shape) > 4: # loaded covariance matrix
            var = np.moveaxis(np.diagonal(var, axis1=1, axis2=2), -1, 1)

        # combine predictions
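        # Moment matching for a uniform mixture of Gaussians (a standard identity,
        # stated here for reference): given member predictions N(mu_i, var_i) for
        # i = 1..M, the mixture p(y) = 1/M * sum_i N(y; mu_i, var_i) is matched by
        #   mu  = 1/M * sum_i mu_i
        #   var = 1/M * sum_i (var_i + mu_i^2) - mu^2   (law of total variance)
        # where 1/M * sum_i var_i is the aleatoric part and the remainder,
        # 1/M * sum_i mu_i^2 - mu^2, is the epistemic part.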

        if ensemble:
            # get the ensemble estimate and its uncertainty by approximating
            # the Gaussian mixture with a single moment-matched Gaussian
            mean_ensemble = 1/n_ensemble * np.sum(mean, axis=0)

            if uncertainty == 'aleatoric':
                # average the members' aleatoric uncertainties
                var_ensemble = 1/n_ensemble * np.sum(var, axis=0)
            elif uncertainty == 'epistemic':
                # compute the variance of the members' mean predictions
                var_ensemble = 1/n_ensemble * np.sum(mean**2, axis=0) - mean_ensemble**2
            elif uncertainty == 'both':
                # combine aleatoric and epistemic uncertainties
                var_ensemble = 1/n_ensemble * np.sum(var + mean**2, axis=0) - mean_ensemble**2
            else: raise NotImplementedError
        else: mean_ensemble, var_ensemble = mean[0], var[0]

        mean_ensemble = torch.tensor(mean_ensemble).cuda()
        var_ensemble  = torch.tensor(var_ensemble).cuda()

        # compute test metrics on ensemble prediction
        extended_metrics = img_metrics(y[0], mean_ensemble.unsqueeze(dim=0),
                                       var=var_ensemble.unsqueeze(dim=0), 
                                       pixelwise=pixelwise)
        img_meter.add(extended_metrics) # accumulate performances over the entire split

        if pixelwise: # collect variances and errors
            vars_aleatoric.extend(extended_metrics['pixelwise var']) 
            errs.extend(extended_metrics['pixelwise error']) 
            errs_se.extend(extended_metrics['pixelwise se']) 
            errs_ae.extend(extended_metrics['pixelwise ae']) 
        else:
            vars_aleatoric.append(extended_metrics['mean var']) 
            errs.append(extended_metrics['error'])
            errs_se.append(extended_metrics['mean se'])
            errs_ae.append(extended_metrics['mean ae'])

        if export_path: # plot and export ensemble predictions
            plot_img(mean_ensemble.unsqueeze(dim=0), 'pred', plot_dir, file_id=idx)
            plot_img(x[0], 'in', plot_dir, file_id=idx)
            plot_img(var_ensemble.mean(dim=0, keepdims=True).expand(3, *var_ensemble.shape[1:]).unsqueeze(dim=0), 'var', plot_dir, file_id=idx)
            export(mean_ensemble[None], 'pred', export_dir, file_id=idx)
            export(var_ensemble[None], 'var', export_dir, file_id=idx)


    # compute UCE and AUCE
    uce_l2, auce_l2 = compute_uce_auce(vars_aleatoric, errs, len(vars_aleatoric), percent=5, l2=True, mode=mode, step=0)
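    # Brief glosses (see compute_uce_auce, imported from train_reconstruct, for the
    # exact definitions used here): UCE is the uncertainty calibration error, i.e.
    # the binned discrepancy between predicted variance and observed squared error;
    # AUCE is the area under the calibration-error curve across confidence levels.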

    # UCE and AUCE are computed once over the full split, so no running mean is needed here
    img_meter.value()['UCE SE']  = uce_l2.cpu().numpy().item()
    img_meter.value()['AUCE SE'] = auce_l2.cpu().numpy().item()

    print(f'{mode} split image metrics: {img_meter.value()}')
    if export_path:
        np.save(os.path.join(export_path, f'pred_var_{uncertainty}.npy'), vars_aleatoric)
        np.save(os.path.join(export_path, 'errors.npy'), errs)
        save_results(img_meter.value(), export_path, split=mode)
        print(f'Exported predictions to path {export_path}')


if __name__ == "__main__":
    main()
    exit()