Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional, Sequence | |
from mmpretrain.registry import METRICS | |
from mmpretrain.structures import label_to_onehot | |
from .multi_label import AveragePrecision, MultiLabelMetric | |
class VOCMetricMixin:
    """A mixin class for VOC dataset metrics.

    VOC annotations have an extra ``difficult`` attribute for each object,
    therefore an extra option is needed when calculating VOC metrics.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in the one-hot ground truth for evaluation.
            If set to True, difficult gt labels are mapped to positive
            ones (1). If set to False, difficult gt labels are mapped to
            negative ones (0). Defaults to None, in which case the difficult
            labels are set to -1.
    """

    def __init__(self,
                 *arg,
                 difficult_as_positive: Optional[bool] = None,
                 **kwarg):
        self.difficult_as_positive = difficult_as_positive
        # Cooperative init: the remaining arguments go to the next class in
        # the MRO (the concrete metric this mixin is combined with).
        super().__init__(*arg, **kwarg)

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results are appended to ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader (unused here).
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each item must provide ``'gt_label'``,
                ``'gt_label_difficult'`` and ``'pred_score'``, and may
                provide ``'gt_score'``.
        """
        for data_sample in data_samples:
            result = dict()
            gt_label = data_sample['gt_label']
            gt_label_difficult = data_sample['gt_label_difficult']
            # Normalize array-like containers (e.g. tensors) to plain Python
            # values. Elements of a tensor are 0-d tensors that hash by
            # identity, which would silently break the set arithmetic below.
            if hasattr(gt_label_difficult, 'tolist'):
                gt_label_difficult = gt_label_difficult.tolist()

            # Clone so that editing the targets never mutates the inputs.
            result['pred_score'] = data_sample['pred_score'].clone()
            num_classes = result['pred_score'].size()[-1]
            if 'gt_score' in data_sample:
                result['gt_score'] = data_sample['gt_score'].clone()
            else:
                result['gt_score'] = label_to_onehot(gt_label, num_classes)

            # A VOC image annotates every object, so a category may appear
            # among both difficult and non-difficult objects. Only categories
            # that appear exclusively among difficult objects are treated as
            # difficult labels.
            difficult_label = set(gt_label_difficult) - set(gt_label.tolist())

            if difficult_label:
                index = sorted(difficult_label)
                if self.difficult_as_positive is None:
                    result['gt_score'][index] = -1
                elif self.difficult_as_positive:
                    result['gt_score'][index] = 1
                # When ``difficult_as_positive`` is False the scores are left
                # as they are (presumably 0, since these labels are absent
                # from ``gt_label``).

            # Save the result to `self.results`.
            self.results.append(result)
@METRICS.register_module()
class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
    """A collection of metrics for multi-label multi-class classification
    tasks based on the confusion matrix, for the VOC dataset.

    It includes precision, recall, f1-score and support.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in the one-hot ground truth for evaluation.
            If set to True, difficult gt labels are mapped to positive
            ones (1). If set to False, difficult gt labels are mapped to
            negative ones (0). Defaults to None, in which case the difficult
            labels are set to -1.
        **kwarg: Refers to `MultiLabelMetric` for detailed docstrings.
    """
@METRICS.register_module()
class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
    """Calculate the average precision with respect to classes for the VOC
    dataset.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in the one-hot ground truth for evaluation.
            If set to True, difficult gt labels are mapped to positive
            ones (1). If set to False, difficult gt labels are mapped to
            negative ones (0). Defaults to None, in which case the difficult
            labels are set to -1.
        **kwarg: Refers to `AveragePrecision` for detailed docstrings.
    """