|
import torch |
|
|
|
|
|
def _take_channels(*xs, ignore_channels=None): |
|
if ignore_channels is None: |
|
return xs |
|
else: |
|
channels = [ |
|
channel |
|
for channel in range(xs[0].shape[1]) |
|
if channel not in ignore_channels |
|
] |
|
xs = [ |
|
torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) |
|
for x in xs |
|
] |
|
return xs |
|
|
|
|
|
def _threshold(x, threshold=None): |
|
if threshold is not None: |
|
return (x > threshold).type(x.dtype) |
|
else: |
|
return x |
|
|
|
|
|
def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the Intersection over Union (Jaccard score) of ``pr`` and ``gt``.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): smoothing term added to numerator and denominator to
            avoid division by zero
        threshold: if not ``None``, ``pr`` is binarized with ``pr > threshold``
        ignore_channels: channel indices (dim 1) excluded from the computation

    Returns:
        torch.Tensor: scalar (0-dim) IoU score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    overlap = (gt * pr).sum()
    combined = gt.sum() + pr.sum()
    # Union = |gt| + |pr| - |gt ∩ pr|; eps keeps empty inputs finite.
    return (overlap + eps) / (combined - overlap + eps)
|
|
|
|
|
# Alias: the Jaccard index is the same quantity as IoU.
jaccard = iou
|
|
|
|
|
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the F-beta score between ground truth and prediction.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        beta (float): positive constant weighting recall against precision
        eps (float): smoothing term added to numerator and denominator to
            avoid division by zero
        threshold: if not ``None``, ``pr`` is binarized with ``pr > threshold``
        ignore_channels: channel indices (dim 1) excluded from the computation

    Returns:
        torch.Tensor: scalar (0-dim) F score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    beta_sq = beta ** 2
    true_pos = (gt * pr).sum()
    false_pos = pr.sum() - true_pos
    false_neg = gt.sum() - true_pos

    weighted_tp = (1 + beta_sq) * true_pos
    numerator = weighted_tp + eps
    denominator = weighted_tp + beta_sq * false_neg + false_pos + eps
    return numerator / denominator
|
|
|
|
|
def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
    """Calculate the accuracy score between ground truth and prediction.

    Accuracy is the fraction of elements whose (binarized) prediction equals
    the ground truth.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor, compared element-wise with ``pr``
        threshold: if not ``None``, ``pr`` is binarized with ``pr > threshold``
        ignore_channels: channel indices (dim 1) excluded from the computation

    Returns:
        torch.Tensor: scalar (0-dim) accuracy score. Its dtype is ``pr.dtype``
        when that is floating point, otherwise torch's default float dtype.
        An empty input yields NaN (0 / 0).
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    # Count matches as int64 instead of in pr.dtype: a float16 sum is inexact
    # above 2048 and overflows above 65504, which is common for image masks.
    matches = torch.eq(gt, pr).sum()
    # numel() also works for non-contiguous tensors, unlike gt.view(-1).
    score = matches / gt.numel()
    if pr.is_floating_point():
        # Preserve the historical result dtype for floating predictions.
        score = score.to(pr.dtype)
    return score
|
|
|
|
|
def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute precision, TP / (TP + FP), between ground truth and prediction.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): smoothing term added to numerator and denominator to
            avoid division by zero
        threshold: if not ``None``, ``pr`` is binarized with ``pr > threshold``
        ignore_channels: channel indices (dim 1) excluded from the computation

    Returns:
        torch.Tensor: scalar (0-dim) precision score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    hits = (gt * pr).sum()
    false_alarms = pr.sum() - hits
    return (hits + eps) / (hits + false_alarms + eps)
|
|
|
|
|
def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute recall, TP / (TP + FN), between ground truth and prediction.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor (the elements to be predicted)
        eps (float): smoothing term added to numerator and denominator to
            avoid division by zero
        threshold: if not ``None``, ``pr`` is binarized with ``pr > threshold``
        ignore_channels: channel indices (dim 1) excluded from the computation

    Returns:
        torch.Tensor: scalar (0-dim) recall score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    hits = (gt * pr).sum()
    misses = gt.sum() - hits
    return (hits + eps) / (hits + misses + eps)
|
|