Spaces:
Sleeping
Sleeping
import torch | |
def _take_channels(*xs, ignore_channels=None):
    """Drop the ignored channels (dimension 1) from every given tensor.

    Args:
        *xs (torch.Tensor): tensors to filter. The channel count is read from
            ``xs[0].shape[1]`` and the same channel indices are applied to
            every tensor.
        ignore_channels: collection of channel indices to drop, or ``None``
            to keep all channels.
    Returns:
        The input tuple ``xs`` unchanged when ``ignore_channels`` is ``None``;
        otherwise a list with one new tensor per input containing only the
        kept channels (possibly zero channels if all of them are ignored).
    """
    if ignore_channels is None:
        return xs

    kept = [
        channel
        for channel in range(xs[0].shape[1])
        if channel not in ignore_channels
    ]
    # Build the index once. The explicit int64 dtype matters when `kept` is
    # empty: torch.tensor([]) would default to float32, which index_select
    # rejects as an index dtype.
    index = torch.tensor(kept, dtype=torch.long)
    return [
        torch.index_select(x, dim=1, index=index.to(x.device))
        for x in xs
    ]
def _threshold(x, threshold=None):
    """Binarize ``x`` against ``threshold``, keeping the dtype of ``x``.

    Args:
        x (torch.Tensor): tensor to binarize
        threshold: cut-off value; elements strictly greater than it become 1,
            all others 0. When ``None`` no binarization is performed.
    Returns:
        ``x`` itself when ``threshold`` is ``None``; otherwise a tensor of
        zeros and ones with the same dtype as ``x``.
    """
    if threshold is None:
        return x
    above = x > threshold
    return above.type(x.dtype)
def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the Intersection over Union (Jaccard index) of prediction and target.

    The sums run over all elements of the tensors at once.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): small constant added to both numerator and denominator
            to avoid zero division
        threshold: if not None, ``pr`` is binarized with it before scoring
        ignore_channels: channel indices (dim 1) left out of the score, or None
    Returns:
        torch.Tensor: scalar IoU (Jaccard) score
    """
    binarized = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(binarized, gt, ignore_channels=ignore_channels)

    overlap = (gt * pr).sum()
    combined = gt.sum() + pr.sum() - overlap + eps
    return (overlap + eps) / combined


# Alternative public name for the same metric.
jaccard = iou
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the F-beta score of a prediction against the ground truth.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        beta (float): positive constant weighting false negatives against
            false positives
        eps (float): small constant added to both numerator and denominator
            to avoid zero division
        threshold: if not None, ``pr`` is binarized with it before scoring
        ignore_channels: channel indices (dim 1) left out of the score, or None
    Returns:
        torch.Tensor: scalar F score
    """
    binarized = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(binarized, gt, ignore_channels=ignore_channels)

    true_pos = (gt * pr).sum()
    false_pos = pr.sum() - true_pos
    false_neg = gt.sum() - true_pos

    beta_sq = beta**2
    numerator = (1 + beta_sq) * true_pos + eps
    denominator = (1 + beta_sq) * true_pos + beta_sq * false_neg + false_pos + eps
    return numerator / denominator
def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
    """Calculate accuracy score between ground truth and prediction

    The score is the fraction of elements where the (binarized) prediction
    equals the ground truth.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        threshold: threshold for outputs binarization (``None`` disables it)
        ignore_channels: channel indices (dim 1) left out of the score, or None
    Returns:
        torch.Tensor: scalar accuracy score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    # Number of positions where prediction and ground truth agree.
    correct = torch.sum(gt == pr, dtype=pr.dtype)
    # numel() rather than view(-1).shape[0]: view() raises on non-contiguous
    # tensors (e.g. a permuted mask), while numel() works for any layout.
    score = correct / gt.numel()
    return score
def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the precision of a prediction against the ground truth.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): small constant added to both numerator and denominator
            to avoid zero division
        threshold: if not None, ``pr`` is binarized with it before scoring
        ignore_channels: channel indices (dim 1) left out of the score, or None
    Returns:
        torch.Tensor: scalar precision score
    """
    binarized = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(binarized, gt, ignore_channels=ignore_channels)

    true_pos = (gt * pr).sum()
    false_pos = pr.sum() - true_pos
    return (true_pos + eps) / (true_pos + false_pos + eps)
def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Compute the recall of a prediction against the ground truth.

    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor): ground truth tensor
        eps (float): small constant added to both numerator and denominator
            to avoid zero division
        threshold: if not None, ``pr`` is binarized with it before scoring
        ignore_channels: channel indices (dim 1) left out of the score, or None
    Returns:
        torch.Tensor: scalar recall score
    """
    binarized = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(binarized, gt, ignore_channels=ignore_channels)

    true_pos = (gt * pr).sum()
    false_neg = gt.sum() - true_pos
    return (true_pos + eps) / (true_pos + false_neg + eps)