# Shawn001's picture
# Upload 53 files
# c2c125c
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#copyright (c) go-hiroaki & Chokurei
#email: [email protected]
# [email protected]
#
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Small positive constant added to denominators to guard against division by zero.
eps = 1.0e-6
def _binarize(y_data, threshold):
"""
args:
y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
threshold : [float] [0.0, 1.0]
return 4-d binarized y_data
"""
y_data[y_data < threshold] = 0.0
y_data[y_data >= threshold] = 1.0
return y_data
def _argmax(y_data, dim):
"""
args:
y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
dim : int
return 3-d [int] y_data
"""
return torch.argmax(y_data, dim).int()
def _get_tp(y_pred, y_true):
"""
args:
y_true : [int] 3-d in [batch_size, img_rows, img_cols]
y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
return [float] true_positive
"""
return torch.sum(y_true * y_pred).float()
def _get_fp(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_positive
"""
return torch.sum((1 - y_true) * y_pred).float()
def _get_tn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] true_negative
"""
return torch.sum((1 - y_true) * (1 - y_pred)).float()
def _get_fn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_negative
"""
return torch.sum(y_true * (1 - y_pred)).float()
def _get_weights(y_true, nb_ch):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
nb_ch : int
return [float] weights
"""
batch_size, img_rows, img_cols = y_true.shape
pixels = batch_size * img_rows * img_cols
weights = [torch.sum(y_true==ch).item() / pixels for ch in range(nb_ch)]
return weights
class CFMatrix(object):
    """Confusion-matrix counts [tp, fp, tn, fn] computed from label maps."""

    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d tensor in [batch_size, img_rows, img_cols]
            y_pred : 3-d tensor in [batch_size, img_rows, img_cols]
            ignore_index : int; also used as the number of classes, so labels
                0 .. ignore_index-1 are scored and pixels whose true label equals
                ignore_index are left out of every count
            threshold : [0.0, 1.0], only used when ignore_index == 1
        return (mperforms, performs)
            ignore_index == 1 : ([tp, fp, tn, fn] of the binarized maps, None)
            otherwise : (per-class rows weighted by each class's pixel share in
                y_true, [chs, 4] tensor whose columns are tp, fp, tn, fn)
        raises ValueError if y_pred is not 3-d
        """
        if y_pred.dim() != 3:
            raise ValueError("y_pred must be 3-d [batch_size, img_rows, img_cols]")
        chs = ignore_index
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            performs = torch.zeros(chs, 4, device=device)
            weights = _get_weights(y_true, chs)
            # Pixels labelled ignore_index never count as negatives for any class.
            valid = y_true != ignore_index
            for ch in range(chs):
                # Build the one-vs-rest masks directly from the inputs so they live
                # on the inputs' device; CPU-allocated zero tensors indexed with CUDA
                # masks raise at runtime.
                y_true_ch = (y_true == ch).float()
                y_false_ch = ((y_true != ch) & valid).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
                nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[ch] = torch.stack([nb_tp, nb_fp, nb_tn, nb_fn])
            mperforms = sum(row * w for row, w in zip(performs, weights))
        return mperforms, performs
class OAAcc(object):
    """Overall accuracy: fraction of pixels whose predicted label equals the true label."""

    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], used to binarize both inputs when chs == 1
        return ((tp + tn) / total, None); for chs != 1 labels are the channel argmax
        """
        n_batch, n_ch, n_rows, n_cols = y_true.shape
        if n_ch != 1:
            # Multi-channel scores: compare the winning channel at every pixel.
            labels_pred = _argmax(y_pred, 1)
            labels_true = _argmax(y_true, 1)
        else:
            labels_pred = _binarize(y_pred, threshold)
            labels_true = _binarize(y_true, threshold)
        n_agree = (labels_true == labels_pred).sum().float()
        n_pixels = n_batch * n_rows * n_cols
        return n_agree / n_pixels, None
class Precision(object):
    """Precision, tp / (tp + fp), for a binary map or multi-channel score maps."""

    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used when chs == 1
        return (mperforms, performs)
            chs == 1 : (tp / (tp + fp) of the binarized maps, None)
            otherwise : (per-class precision weighted by each class's pixel share
                in y_true, per-class precision as a [chs, 1] tensor); labels are
                the channel argmax
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1, device=device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks built from the inputs keep them on the inputs'
                # device (CPU zero tensors cannot be indexed by CUDA masks).
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[ch] = nb_tp / (nb_tp + nb_fp + eps)
            mperforms = sum(p * w for p, w in zip(performs, weights))
        return mperforms, performs
class Recall(object):
    """Recall, tp / (tp + fn), for a binary map or multi-channel score maps."""

    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used when chs == 1
        return (mperforms, performs)
            chs == 1 : (tp / (tp + fn) of the binarized maps, None)
            otherwise : (per-class recall weighted by each class's pixel share
                in y_true, per-class recall as a [chs, 1] tensor); labels are
                the channel argmax
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1, device=device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks built from the inputs keep them on the inputs'
                # device (CPU zero tensors cannot be indexed by CUDA masks).
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[ch] = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = sum(p * w for p, w in zip(performs, weights))
        return mperforms, performs
class F1Score(object):
    """F1 score, 2 * precision * recall / (precision + recall), for binary or multi-channel maps."""

    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    @staticmethod
    def _f1(y_pred, y_true):
        """F1 of two 0/1 maps from eps-guarded precision and recall."""
        nb_tp = _get_tp(y_pred, y_true)
        nb_fp = _get_fp(y_pred, y_true)
        nb_fn = _get_fn(y_pred, y_true)
        _precision = nb_tp / (nb_tp + nb_fp + eps)
        _recall = nb_tp / (nb_tp + nb_fn + eps)
        return 2 * _precision * _recall / (_precision + _recall + eps)

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used when chs == 1
        return (mperforms, performs)
            chs == 1 : (F1 of the binarized maps, None)
            otherwise : (per-class F1 weighted by each class's pixel share in
                y_true, per-class F1 as a [chs, 1] tensor); labels are the
                channel argmax
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            mperforms = self._f1(y_pred, y_true)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1, device=device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks built from the inputs keep them on the inputs'
                # device (CPU zero tensors cannot be indexed by CUDA masks).
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                performs[ch] = self._f1(y_pred_ch, y_true_ch)
            mperforms = sum(p * w for p, w in zip(performs, weights))
        return mperforms, performs
class Kappa(object):
    """Cohen's kappa, (Po - Pe) / (1 - Pe), for binary or multi-channel maps."""

    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    @staticmethod
    def _kappa(y_pred, y_true):
        """Kappa of two 0/1 maps: observed agreement Po vs. chance agreement Pe from the marginals."""
        nb_tp = _get_tp(y_pred, y_true)
        nb_fp = _get_fp(y_pred, y_true)
        nb_tn = _get_tn(y_pred, y_true)
        nb_fn = _get_fn(y_pred, y_true)
        nb_total = nb_tp + nb_fp + nb_tn + nb_fn
        Po = (nb_tp + nb_tn) / nb_total
        Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
              + (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total ** 2)
        return (Po - Pe) / (1 - Pe + eps)

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used when chs == 1
        return (mperforms, performs)
            chs == 1 : (kappa of the binarized maps, None)
            otherwise : (per-class kappa weighted by each class's pixel share in
                y_true, per-class kappa as a [chs, 1] tensor); labels are the
                channel argmax
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            mperforms = self._kappa(y_pred, y_true)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1, device=device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks built from the inputs keep them on the inputs'
                # device (CPU zero tensors cannot be indexed by CUDA masks).
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                performs[ch] = self._kappa(y_pred_ch, y_true_ch)
            mperforms = sum(p * w for p, w in zip(performs, weights))
        return mperforms, performs
class Jaccard(object):
    """Jaccard index (IoU), intersection / union, for binary or multi-channel maps."""

    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used when chs == 1
        return (mperforms, performs)
            chs == 1 : (intersection / (sum - intersection) of the binarized maps, None)
            otherwise : (per-class IoU weighted by each class's pixel share in
                y_true, per-class IoU as a [chs, 1] tensor); labels are the
                channel argmax
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            # sum - intersection == union for 0/1 maps.
            mperforms = _intersec / (_sum - _intersec + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1, device=device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks built from the inputs keep them on the inputs'
                # device (CPU zero tensors cannot be indexed by CUDA masks).
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[ch] = _intersec / (_sum - _intersec + eps)
            mperforms = sum(p * w for p, w in zip(performs, weights))
        return mperforms, performs
class MSE(object):
    """Mean squared error between two tensors; smaller is better."""

    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            dim : unused; kept for call-signature compatibility
            threshold : None or [0.0, 1.0]; when given, y_pred is binarized first
        return mean_squared_error as a 0-d tensor, smaller the better
        """
        # Compare against None: 0.0 is inside the documented range but is falsy.
        if threshold is not None:
            y_pred = _binarize(y_pred, threshold)
        return torch.mean((y_pred - y_true) ** 2)
class PSNR(object):
    """Peak signal-to-noise ratio in dB; larger is better."""

    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None, max_val=1.0):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            dim : unused; kept for call-signature compatibility
            threshold : None or [0.0, 1.0]; when given, y_pred is binarized first
            max_val : float, peak value of the data range (default 1.0 for data
                in [0, 1]; e.g. 255 for 8-bit images)
        return 10 * log10(max_val ** 2 / mse) as a 0-d tensor, larger the better
            (inf when the inputs are identical)
        """
        # Compare against None: 0.0 is inside the documented range but is falsy.
        if threshold is not None:
            y_pred = _binarize(y_pred, threshold)
        mse = torch.mean((y_pred - y_true) ** 2)
        return 10 * torch.log10(max_val ** 2 / mse)
class SSIM(object):
    '''
    Structural similarity index (SSIM) with a Gaussian window; larger is better.

    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''

    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        """Return a 1-d Gaussian of length w_size, centred at w_size // 2 and normalized to sum 1."""
        gauss = torch.Tensor([math.exp(-(x - w_size // 2) ** 2 / float(2 * sigma ** 2))
                              for x in range(w_size)])
        return gauss / gauss.sum()

    def create_window(self, w_size, channel=1):
        """Return a [channel, 1, w_size, w_size] 2-d Gaussian kernel (sigma 1.5) for a grouped conv."""
        _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
        return window

    def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False, val_range=None):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11; clipped to the image height/width when larger
            size_average : boolean, default True; False returns one value per sample
            full : boolean, default False; True also returns the mean contrast sensitivity
            val_range : dynamic range L of the data, default None (inferred from y_pred)
        return ssim, or (ssim, cs) when full; larger the better
        """
        if val_range is None:
            # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
            max_val = 255 if torch.max(y_pred) > 128 else 1
            min_val = -1 if torch.min(y_pred) < -0.5 else 0
            L = max_val - min_val
        else:
            L = val_range
        padd = 0
        (_, channel, height, width) = y_pred.size()
        # The convolution is unpadded, so a window larger than the image would fail.
        real_size = min(w_size, height, width)
        # Match dtype as well as device so float64 / float16 inputs can be convolved.
        window = self.create_window(real_size, channel=channel).to(
            device=y_pred.device, dtype=y_pred.dtype)
        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2
        sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd, groups=channel) - mu2_sq
        sigma12 = F.conv2d(y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2
        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2
        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity
        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
        if size_average:
            ret = ssim_map.mean()
        else:
            # Average over channel and spatial dims, keeping the batch dim.
            ret = ssim_map.mean(1).mean(1).mean(1)
        if full:
            return ret, cs
        return ret
class AE(object):
    """
    Average angular error (in degrees) between per-pixel channel vectors.

    Modified from matlab : colorangle.m, MATLAB V2019b
        angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
        angle = 180 / pi * angle;
    """

    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
        return angular error in degrees averaged over pixels, one value per
            sample ([batch_size] tensor), smaller the better
        """
        # Per-pixel dot product and vector norms across the channel axis.
        dotP = torch.sum(y_pred * y_true, dim=1)
        Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        # eps keeps all-zero pixels finite (their cosine is 0, i.e. 90 degrees).
        cos_angle = dotP / (Norm_pred * Norm_true + eps)
        # Rounding can push the ratio just outside [-1, 1] for nearly parallel
        # vectors with large norms, where acos would return NaN.
        cos_angle = torch.clamp(cos_angle, -1.0, 1.0)
        ae = 180 / math.pi * torch.acos(cos_angle)
        return ae.mean(1).mean(1)
if __name__ == "__main__":
    # Smoke test: run every metric on a noisy copy of a random image, on the CPU
    # and additionally on CUDA when a GPU is available.
    run_on_cuda = [False, True] if torch.cuda.is_available() else [False]
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        for cuda in run_on_cuda:
            if cuda:
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()
            print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
            ########### similarity metrics
            similarity_metrics = [MSE(), PSNR(), SSIM()]
            # LPIPS is not defined or imported in this file; run it only if it has
            # been provided at module level.
            lpips_cls = globals().get("LPIPS")
            if lpips_cls is not None:
                similarity_metrics.append(lpips_cls(cuda))
            similarity_metrics.append(AE())
            for metric in similarity_metrics:
                acc = metric(y_pred, y_true).item()
                print("{} ==> {}".format(repr(metric), acc))
            ########### accuracy metrics
            accuracy_metrics = [
                ('Accu', OAAcc()),
                ('Prec', Precision()),
                ('Reca', Recall()),
                ('F1sc', F1Score()),
                ('Kapp', Kappa()),
                ('Jacc', Jaccard()),
            ]
            for label, metric in accuracy_metrics:
                mvalue, value = metric(y_pred, y_true)
                print('m{}:'.format(label), mvalue, label, value)