from collections import OrderedDict

import annotator.mmpkg.mmcv as mmcv
import numpy as np
import torch


def f_score(precision, recall, beta=1):
    """Calculate the f-score value.

    Args:
        precision (float | torch.Tensor): The precision value.
        recall (float | torch.Tensor): The recall value.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        [torch.tensor]: The f-score value.
    """
    score = (1 + beta**2) * (precision * recall) / (
        (beta**2 * precision) + recall)
    return score
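
# Illustrative check (editor's addition, not part of the original module):
# with beta=1 the score reduces to the harmonic mean of precision and recall,
#   f_score(0.8, 0.4) == 2 * (0.8 * 0.4) / (0.8 + 0.4) ~= 0.533
# and beta > 1 weights recall more heavily than precision.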


def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=dict(),
                        reduce_zero_label=False):
    """Calculate Intersection and Union.

    Args:
        pred_label (ndarray | str): Prediction segmentation map
            or predict result filename.
        label (ndarray | str): Ground truth segmentation map
            or label filename.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. The parameter will
            work only when label is str. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. The parameter will
            work only when label is str. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """

    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    else:
        pred_label = torch.from_numpy(pred_label)

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        label = torch.from_numpy(label)

    if label_map is not None:
        # Remap old label ids to new label ids in place.
        for old_id, new_id in label_map.items():
            label[label == old_id] = new_id
    if reduce_zero_label:
        # Shift labels down by one so that class 0 is ignored:
        # 0 -> 255, k -> k - 1, and the former 255 (now 254) stays at 255.
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label
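
# Illustrative example (editor's addition, not part of the original module):
# for pred = [0, 1, 1] and gt = [0, 1, 2] with num_classes=3 and
# ignore_index=255, the per-class histograms are
#   area_intersect = [1, 1, 0], area_pred_label = [1, 2, 0],
#   area_label     = [1, 1, 1], area_union      = [1, 2, 1].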


def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(
                results[i], gt_seg_maps[i], num_classes, ignore_index,
                label_map, reduce_zero_label)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label


def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection over Union (mIoU).

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <IoU> ndarray: Per category IoU, shape (num_classes, ).
    """
    iou_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return iou_result
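
# Illustrative usage sketch (editor's addition; the variable names below are
# made up for the example): preds and gts are lists of HxW integer class maps
# with matching shapes.
#   results = mean_iou(preds, gts, num_classes=21, ignore_index=255)
#   print(results['aAcc'], results['IoU'].mean())
# mean_dice and mean_fscore below are used the same way and differ only in
# the per-category metric they return.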


def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <Dice> ndarray: Per category dice, shape (num_classes, ).
    """
    dice_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return dice_result


def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=dict(),
                reduce_zero_label=False,
                beta=1):
    """Calculate Mean F-Score (mFscore).

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
            <Precision> ndarray: Per category precision, shape (num_classes, ).
            <Recall> ndarray: Per category recall, shape (num_classes, ).
    """
    fscore_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
    return fscore_result
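
# Note (editor's addition): beta controls the precision/recall trade-off in
# the returned Fscore, e.g. beta=2 weights recall more heavily:
#   mean_fscore(preds, gts, num_classes=21, ignore_index=255, beta=2)
# where preds and gts are lists of integer class maps, as in the mean_iou
# sketch above.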


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, among 'mIoU',
            'mDice' and 'mFscore'. Default: ['mIoU'].
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Determines the weight of recall in the combined f-score.
            Default: 1.

    Returns:
        dict[str, float | ndarray]: Overall accuracy ('aAcc', float) together
            with per category accuracy and the requested per category
            metrics, each an ndarray of shape (num_classes, ).
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(
            results, gt_seg_maps, num_classes, ignore_index, label_map,
            reduce_zero_label)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            f_value = torch.tensor(
                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    # Convert accumulated torch tensors to numpy arrays for the caller.
    ret_metrics = {
        metric: value.numpy()
        for metric, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict({
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in ret_metrics.items()
        })
    return ret_metrics
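

# Minimal smoke-test sketch (editor's addition, not part of the original
# module): evaluates random prediction/ground-truth maps end to end. The
# shapes and class count below are arbitrary assumptions for illustration.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    num_classes = 5
    preds = [rng.integers(0, num_classes, size=(32, 32)) for _ in range(4)]
    gts = [rng.integers(0, num_classes, size=(32, 32)) for _ in range(4)]
    demo_metrics = eval_metrics(
        preds, gts, num_classes=num_classes, ignore_index=255,
        metrics=['mIoU', 'mFscore'], nan_to_num=0)
    print('aAcc:', demo_metrics['aAcc'])
    print('mIoU:', demo_metrics['IoU'].mean())
    print('mFscore:', demo_metrics['Fscore'].mean())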