|
import os |
|
import sys |
|
import torch |
|
import numpy as np |
|
|
|
sys.path.append(os.path.dirname(os.getcwd())) |
|
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) |
|
from util import pytorch_ssim |
|
|
|
class Metric:
    """Minimal meter interface: ``reset`` / ``add`` / ``value``.

    Adapted from torchnet:
    https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py

    Every hook here is a no-op that returns ``None``; concrete meters
    override the ones they need.
    """

    def reset(self):
        """Clear accumulated state (no-op in the base class)."""

    def add(self):
        """Accumulate an observation (no-op in the base class)."""

    def value(self):
        """Return the current value (``None`` in the base class)."""
|
|
|
|
|
def _to_float(t):
    """Return the value of a single-element tensor as a Python number."""
    # .item() also works on tensors that require grad and on bfloat16,
    # where .numpy() would raise.
    return t.item()


def _pixelwise(t):
    """NaN-average ``t`` over its first two dims and return it flattened as a NumPy array."""
    # detach() so tensors that are part of an autograd graph can be converted.
    return t.nanmean(0).nanmean(0).flatten().detach().cpu().numpy()


def img_metrics(target, pred, var=None, pixelwise=True):
    """Compute reconstruction metrics between ``target`` and ``pred``.

    Scalar metrics (always returned, as Python numbers):
      * ``RMSE`` -- root mean squared error over all elements.
      * ``MAE``  -- mean absolute error over all elements.
      * ``PSNR`` -- ``20 * log10(1 / RMSE)``; a peak value of 1 is assumed
        (presumably inputs are scaled to [0, 1] -- confirm against callers).
      * ``SAM``  -- mean spectral angle in degrees between the per-pixel vectors
        taken along dim 1 (assumed to be the channel/band axis -- TODO confirm).
        Pixels where either vector is all-zero have an undefined angle (0/0)
        and are excluded from the mean instead of turning SAM into NaN.
      * ``SSIM`` -- value returned by ``pytorch_ssim.ssim(target, pred)``.

    If ``var`` is given, these keys are added (NaN-ignoring means):
      * ``error``, ``mean ae``, ``mean se``, ``mean var`` -- means of the signed
        error, absolute error, squared error and ``var``.
      * if ``pixelwise`` is true, also ``pixelwise error`` / ``pixelwise ae`` /
        ``pixelwise se`` / ``pixelwise var``: NumPy arrays obtained by
        NaN-averaging over dims 0 and 1 and flattening the remaining dims.

    Args:
        target: reference tensor.
        pred: predicted tensor, same shape as ``target``.
        var: optional tensor of predicted variances. It is assumed to have the
            same layout as ``target - pred`` so the pixelwise arrays line up
            (NOTE(review): confirm against callers).
        pixelwise: include the per-pixel arrays (only used when ``var`` is given).

    Returns:
        dict mapping metric names to Python numbers, or to NumPy arrays for the
        ``pixelwise ...`` entries.
    """
    error = target - pred

    rmse = torch.sqrt(torch.mean(torch.square(error)))
    psnr = 20 * torch.log10(1 / rmse)
    mae = torch.mean(torch.abs(error))

    # Spectral angle mapper: cosine of the angle between per-pixel vectors along
    # dim 1. If either vector has zero norm, its dot product is also 0, so the
    # pixel evaluates to 0/0 = NaN; clamp and acos propagate that NaN.
    cos = torch.sum(target * pred, 1)
    cos = torch.div(cos, torch.sqrt(torch.sum(target * target, 1)))
    cos = torch.div(cos, torch.sqrt(torch.sum(pred * pred, 1)))
    # nanmean drops those undefined pixels instead of making the whole SAM NaN.
    sam = (torch.acos(torch.clamp(cos, -1, 1)) * 180 / torch.pi).nanmean()

    ssim = pytorch_ssim.ssim(target, pred)

    metric_dict = {'RMSE': _to_float(rmse),
                   'MAE': _to_float(mae),
                   'PSNR': _to_float(psnr),
                   'SAM': _to_float(sam),
                   'SSIM': _to_float(ssim)}

    if var is not None:
        se = torch.square(error)
        ae = torch.abs(error)

        # sample-wise statistics (NaN-ignoring means over all elements)
        metric_dict.update({'error': _to_float(error.nanmean()),
                            'mean ae': _to_float(ae.nanmean()),
                            'mean se': _to_float(se.nanmean()),
                            'mean var': _to_float(var.nanmean())})

        if pixelwise:
            metric_dict.update({'pixelwise error': _pixelwise(error),
                                'pixelwise ae': _pixelwise(ae),
                                'pixelwise se': _pixelwise(se),
                                'pixelwise var': _pixelwise(var)})

    return metric_dict
|
|
|
class avg_img_metrics(Metric):
    """Running, NaN-skipping averages of the scalar metrics from ``img_metrics``.

    Only the keys listed in ``self.metrics`` are tracked. Every other key
    passed to :meth:`add` (e.g. the ``pixelwise ...`` arrays) is ignored.
    """

    def __init__(self):
        super().__init__()
        # Kept for backward compatibility; this class never updates it.
        self.n_samples = 0
        self.metrics = ['RMSE', 'MAE', 'PSNR', 'SAM', 'SSIM']
        self.metrics += ['error', 'mean se', 'mean ae', 'mean var']

        # metric name -> running mean (NaN until the first valid value arrives)
        self.running_img_metrics = {}
        # metric name -> number of non-NaN values folded into the running mean
        self.running_nonan_count = {}
        self.reset()

    def reset(self):
        """Set every tracked running mean back to NaN and every count to zero."""
        for metric in self.metrics:
            self.running_nonan_count[metric] = 0
            self.running_img_metrics[metric] = np.nan

    def add(self, metrics_dict):
        """Fold one sample's metrics into the running means.

        Args:
            metrics_dict: mapping of metric name to value. Keys not listed in
                ``self.metrics`` are ignored. Tensor, ``None`` and NaN values
                are skipped. A tuple value contributes its first element.
        """
        for key, val in metrics_dict.items():
            if key not in self.metrics:
                continue
            if torch.is_tensor(val):
                continue
            if isinstance(val, tuple):
                val = val[0]
            # np.isnan(None) would raise TypeError, so treat None as missing.
            if val is None or np.isnan(val):
                continue

            count = self.running_nonan_count[key] + 1
            self.running_nonan_count[key] = count
            if count == 1:
                self.running_img_metrics[key] = val
            else:
                # incremental mean: weight old mean by (n-1)/n, new value by 1/n
                self.running_img_metrics[key] = (count - 1) / count * self.running_img_metrics[key] \
                    + 1 / count * val

    def value(self):
        """Return a snapshot (shallow copy) of the running means, keyed by metric name."""
        # A copy, so a later reset() (which mutates the internal dict in place)
        # does not overwrite a previously returned result with NaNs.
        return dict(self.running_img_metrics)