File size: 4,022 Bytes
3c8ff2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import sys
import torch
import numpy as np

sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
from util import pytorch_ssim

class Metric:
    """Minimal interface for metric accumulators.

    From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py

    All three methods are no-ops in this base class and return ``None``.
    """

    def reset(self):
        """Do nothing; placeholder for clearing accumulated state."""
        return None

    def add(self):
        """Do nothing; placeholder for recording a new observation."""
        return None

    def value(self):
        """Do nothing; placeholder for reporting the current value."""
        return None


def img_metrics(target, pred, var=None, pixelwise=True):
    """Compute image-quality metrics between ``target`` and ``pred``.

    The returned dictionary always holds the scalar entries ``'RMSE'``,
    ``'MAE'``, ``'PSNR'``, ``'SAM'`` and ``'SSIM'`` as Python floats.
    When ``var`` is given, error/uncertainty statistics are appended:
    ``'error'``, ``'mean ae'``, ``'mean se'`` and ``'mean var'`` (floats),
    plus, if ``pixelwise`` is true, ``'pixelwise error'``,
    ``'pixelwise ae'``, ``'pixelwise se'`` and ``'pixelwise var'``
    (flattened NumPy arrays).

    Args:
        target: Reference tensor.
        pred: Predicted tensor, broadcast-compatible with ``target``.
            NOTE(review): dim 1 is reduced for the spectral angle, so it is
            presumably the channel/band axis -- confirm against callers.
        var: Optional tensor of predicted variances; ``None`` skips the
            error/uncertainty statistics.
        pixelwise: Whether to also report statistics averaged over the
            first two dimensions (only used when ``var`` is given).

    Returns:
        dict mapping metric names to floats (and NumPy arrays for the
        ``'pixelwise ...'`` entries).
    """
    diff = target - pred

    rmse = torch.sqrt(torch.mean(torch.square(diff)))
    # PSNR with a peak signal value of 1.
    psnr = 20 * torch.log10(1 / rmse)
    mae = torch.mean(torch.abs(diff))

    # Spectral angle mapper: angle (in degrees) between target and pred
    # vectors taken along dim 1. A zero-norm vector yields NaN here.
    cos_angle = torch.sum(target * pred, 1)
    cos_angle = torch.div(cos_angle, torch.sqrt(torch.sum(target * target, 1)))
    cos_angle = torch.div(cos_angle, torch.sqrt(torch.sum(pred * pred, 1)))
    sam = torch.mean(torch.acos(torch.clamp(cos_angle, -1, 1)) * 180 / torch.pi)

    ssim = pytorch_ssim.ssim(target, pred)

    # Tensor.item() works on any device and on tensors that require grad,
    # whereas .cpu().numpy() raises for the latter.
    metric_dict = {'RMSE': rmse.item(),
                   'MAE': mae.item(),
                   'PSNR': psnr.item(),
                   'SAM': sam.item(),
                   'SSIM': ssim.item()}

    # Evaluate the (optional) variance maps.
    if var is not None:
        se = torch.square(diff)
        ae = torch.abs(diff)

        def _per_pixel(stat):
            # NaN-aware mean over the first two dimensions, flattened.
            # detach() keeps the NumPy conversion valid under autograd.
            return stat.nanmean(0).nanmean(0).flatten().detach().cpu().numpy()

        # Treat the whole input as one sample: NaN-aware global means.
        errvar_samplewise = {'error': diff.nanmean().item(),
                             'mean ae': ae.nanmean().item(),
                             'mean se': se.nanmean().item(),
                             'mean var': var.nanmean().item()}
        if pixelwise:
            # Treat every pixel as a sample: one statistic per pixel.
            errvar_samplewise = {**errvar_samplewise,
                                 'pixelwise error': _per_pixel(diff),
                                 'pixelwise ae': _per_pixel(ae),
                                 'pixelwise se': _per_pixel(se),
                                 'pixelwise var': _per_pixel(var)}

        metric_dict = {**metric_dict, **errvar_samplewise}

    return metric_dict

class avg_img_metrics(Metric):
    """Running, NaN-skipping mean of a fixed set of scalar image metrics."""

    def __init__(self):
        """Register the tracked metric names and start from a clean state."""
        super().__init__()
        # Not updated by any method of this class.
        self.n_samples = 0
        self.metrics = [
            'RMSE', 'MAE', 'PSNR', 'SAM', 'SSIM',
            'error', 'mean se', 'mean ae', 'mean var',
        ]

        self.running_img_metrics = {}
        self.running_nonan_count = {}
        self.reset()

    def reset(self):
        """Set every tracked metric back to NaN with a zero sample count."""
        # Updated in place so references handed out by value() stay valid.
        self.running_nonan_count.update(dict.fromkeys(self.metrics, 0))
        self.running_img_metrics.update(dict.fromkeys(self.metrics, np.nan))

    def add(self, metrics_dict):
        """Fold one dictionary of metric values into the running means.

        Entries are ignored when the key is not registered, the value is
        still a tensor, or the value is NaN. A tuple value contributes its
        first element.
        """
        for name, sample in metrics_dict.items():
            # Unregistered keys and not-yet-converted tensors are dropped.
            if name not in self.metrics or torch.is_tensor(sample):
                continue
            if isinstance(sample, tuple):
                sample = sample[0]

            # NaN samples must not contaminate the running mean.
            if np.isnan(sample):
                continue

            seen = self.running_nonan_count[name] + 1
            self.running_nonan_count[name] = seen
            if seen == 1:
                # First valid sample replaces the NaN placeholder.
                self.running_img_metrics[name] = sample
            else:
                previous = self.running_img_metrics[name]
                self.running_img_metrics[name] = (
                    (seen - 1) / seen * previous + 1 / seen * sample
                )

    def value(self):
        """Return the dictionary of current running means."""
        return self.running_img_metrics