File size: 3,891 Bytes
2a13495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
def _take_channels(*xs, ignore_channels=None):
    """Drop the channels listed in ``ignore_channels`` from every tensor in ``xs``.

    Channels are selected along dim 1, and the channel count is read from
    ``xs[0]``. The other tensors are assumed to have at least that many
    channels along dim 1.

    Args:
        *xs (torch.Tensor): tensors to filter.
        ignore_channels: collection of channel indices (along dim 1) to remove,
            or ``None`` to keep every channel.

    Returns:
        The input tuple ``xs`` unchanged when ``ignore_channels`` is ``None``;
        otherwise a list of the filtered tensors, in the same order as ``xs``.
    """
    if ignore_channels is None:
        return xs
    channels = [c for c in range(xs[0].shape[1]) if c not in ignore_channels]
    # An explicit int64 dtype is required: torch.tensor([]) defaults to float32
    # and index_select rejects non-integer indices, so ignoring every channel
    # would crash instead of yielding empty tensors.
    index = torch.tensor(channels, dtype=torch.long)
    return [torch.index_select(x, dim=1, index=index.to(x.device)) for x in xs]
def _threshold(x, threshold=None):
    """Binarize ``x`` against ``threshold`` while keeping its dtype.

    Elements strictly greater than ``threshold`` become 1 and all others 0.
    When ``threshold`` is ``None`` the very same object is returned untouched.
    """
    if threshold is None:
        return x
    above = x > threshold
    return above.type(x.dtype)
def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute Intersection over Union (Jaccard index) of a prediction.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): small constant added to both numerator and denominator;
            it avoids 0/0, so an empty prediction against an empty ground
            truth scores 1
        threshold: if not None, ``pr`` is binarized as ``pr > threshold``
        ignore_channels: channel indices along dim 1 to exclude, or None

    Returns:
        torch.Tensor: scalar IoU (Jaccard) score
    """
    binarized = _threshold(pr, threshold=threshold)
    pred, target = _take_channels(binarized, gt, ignore_channels=ignore_channels)
    overlap = torch.sum(target * pred)
    combined = torch.sum(target) + torch.sum(pred)
    # union = |gt| + |pr| - |gt & pr|
    return (overlap + eps) / (combined - overlap + eps)


# The Jaccard index is the same quantity as IoU; expose it under both names.
jaccard = iou
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the F-beta score of a prediction against the ground truth.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        beta (float): positive constant weighting recall relative to precision
            (``beta ** 2`` multiplies the false-negative term)
        eps (float): small constant added to numerator and denominator to
            avoid 0/0
        threshold: if not None, ``pr`` is binarized as ``pr > threshold``
        ignore_channels: channel indices along dim 1 to exclude, or None

    Returns:
        torch.Tensor: scalar F score
    """
    binarized = _threshold(pr, threshold=threshold)
    pred, target = _take_channels(binarized, gt, ignore_channels=ignore_channels)

    true_pos = torch.sum(target * pred)
    false_pos = torch.sum(pred) - true_pos
    false_neg = torch.sum(target) - true_pos

    beta_sq = beta ** 2
    weighted_tp = (1 + beta_sq) * true_pos
    numerator = weighted_tp + eps
    denominator = weighted_tp + beta_sq * false_neg + false_pos + eps
    return numerator / denominator
def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
    """Calculate the fraction of elements where prediction equals ground truth.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        threshold: threshold for outputs binarization (``pr > threshold``);
            None compares ``pr`` as-is
        ignore_channels: channel indices along dim 1 to exclude, or None

    Returns:
        torch.Tensor: scalar accuracy score, in ``pr``'s dtype when ``pr`` is
        floating point, otherwise in torch's default float dtype. An empty
        ``gt`` yields NaN (0 / 0).
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    out_dtype = pr.dtype if pr.is_floating_point() else torch.get_default_dtype()
    # Count matches as int64 and divide in at least float32: accumulating the
    # count in pr's dtype overflows float16 beyond 65504 and breaks for bool.
    work_dtype = torch.promote_types(out_dtype, torch.float32)
    correct = torch.sum(gt == pr)
    # numel() works for non-contiguous tensors, unlike gt.view(-1).
    score = correct.to(work_dtype) / gt.numel()
    return score.to(out_dtype)
def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute precision, TP / (TP + FP), of a prediction.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): small constant added to numerator and denominator to
            avoid 0/0
        threshold: if not None, ``pr`` is binarized as ``pr > threshold``
        ignore_channels: channel indices along dim 1 to exclude, or None

    Returns:
        torch.Tensor: scalar precision score
    """
    binarized = _threshold(pr, threshold=threshold)
    pred, target = _take_channels(binarized, gt, ignore_channels=ignore_channels)
    hits = torch.sum(target * pred)
    false_alarms = torch.sum(pred) - hits
    return (hits + eps) / (hits + false_alarms + eps)
def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute recall, TP / (TP + FN), of a prediction.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor (elements that should be found)
        eps (float): small constant added to numerator and denominator to
            avoid 0/0
        threshold: if not None, ``pr`` is binarized as ``pr > threshold``
        ignore_channels: channel indices along dim 1 to exclude, or None

    Returns:
        torch.Tensor: scalar recall score
    """
    binarized = _threshold(pr, threshold=threshold)
    pred, target = _take_channels(binarized, gt, ignore_channels=ignore_channels)
    hits = torch.sum(target * pred)
    misses = torch.sum(target) - hits
    return (hits + eps) / (hits + misses + eps)
|