XavierJiezou's picture
Upload folder using huggingface_hub
3c8ff2e verified
import os
import sys
import torch
import numpy as np
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
from util import pytorch_ssim
class Metric(object):
"""Base class for all metrics.
From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
"""
def reset(self): pass
def add(self): pass
def value(self): pass
def img_metrics(target, pred, var=None, pixelwise=True):
rmse = torch.sqrt(torch.mean(torch.square(target - pred)))
psnr = 20 * torch.log10(1 / rmse)
mae = torch.mean(torch.abs(target - pred))
# spectral angle mapper
mat = target * pred
mat = torch.sum(mat, 1)
mat = torch.div(mat, torch.sqrt(torch.sum(target * target, 1)))
mat = torch.div(mat, torch.sqrt(torch.sum(pred * pred, 1)))
sam = torch.mean(torch.acos(torch.clamp(mat, -1, 1))*torch.tensor(180)/torch.pi)
ssim = pytorch_ssim.ssim(target, pred)
metric_dict = {'RMSE': rmse.cpu().numpy().item(),
'MAE': mae.cpu().numpy().item(),
'PSNR': psnr.cpu().numpy().item(),
'SAM': sam.cpu().numpy().item(),
'SSIM': ssim.cpu().numpy().item()}
# evaluate the (optional) variance maps
if var is not None:
error = target - pred
# average across the spectral dimensions
se = torch.square(error)
ae = torch.abs(error)
# collect sample-wise error, AE, SE and uncertainties
# define a sample as 1 image and provide image-wise statistics
errvar_samplewise = {'error': error.nanmean().cpu().numpy().item(),
'mean ae': ae.nanmean().cpu().numpy().item(),
'mean se': se.nanmean().cpu().numpy().item(),
'mean var': var.nanmean().cpu().numpy().item()}
if pixelwise:
# define a sample as 1 multivariate pixel and provide image-wise statistics
errvar_samplewise = {**errvar_samplewise, **{'pixelwise error': error.nanmean(0).nanmean(0).flatten().cpu().numpy(),
'pixelwise ae': ae.nanmean(0).nanmean(0).flatten().cpu().numpy(),
'pixelwise se': se.nanmean(0).nanmean(0).flatten().cpu().numpy(),
'pixelwise var': var.nanmean(0).nanmean(0).flatten().cpu().numpy()}}
metric_dict = {**metric_dict, **errvar_samplewise}
return metric_dict
class avg_img_metrics(Metric):
def __init__(self):
super().__init__()
self.n_samples = 0
self.metrics = ['RMSE', 'MAE', 'PSNR','SAM','SSIM']
self.metrics += ['error', 'mean se', 'mean ae', 'mean var']
self.running_img_metrics = {}
self.running_nonan_count = {}
self.reset()
def reset(self):
for metric in self.metrics:
self.running_nonan_count[metric] = 0
self.running_img_metrics[metric] = np.nan
def add(self, metrics_dict):
for key, val in metrics_dict.items():
# skip variables not registered
if key not in self.metrics: continue
# filter variables not translated to numpy yet
if torch.is_tensor(val): continue
if isinstance(val, tuple): val=val[0]
# only keep a running mean of non-nan values
if np.isnan(val): continue
if not self.running_nonan_count[key]:
self.running_nonan_count[key] = 1
self.running_img_metrics[key] = val
else:
self.running_nonan_count[key]+= 1
self.running_img_metrics[key] = (self.running_nonan_count[key]-1)/self.running_nonan_count[key] * self.running_img_metrics[key] \
+ 1/self.running_nonan_count[key] * val
def value(self):
return self.running_img_metrics