# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from typing import Sequence, Union

from mmengine.dist import (broadcast_object_list, collect_results,
                           is_main_process)
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.evaluator.metric import _to_cpu

from mmocr.registry import EVALUATOR
from mmocr.utils.typing_utils import ConfigType

@EVALUATOR.register_module()
class MultiDatasetsEvaluator(Evaluator):
    """Wrapper class to compose :class:`ConcatDataset` and multiple
    :class:`BaseMetric` instances.

    The metrics will be evaluated on each dataset slice separately. The name
    of each metric is the concatenation of the dataset prefix, the metric
    prefix and the key of the metric - e.g.
    ``dataset_prefix/metric_prefix/accuracy``.

    Args:
        metrics (dict or BaseMetric or Sequence): The config of metrics.
        dataset_prefixes (Sequence[str]): The prefix of each dataset. The
            length of this sequence should be the same as the length of the
            datasets.
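
    Example:
        A minimal config-style sketch (the metric type, its arguments and the
        dataset prefixes below are illustrative placeholders, not part of
        this module)::

            val_evaluator = dict(
                type='MultiDatasetsEvaluator',
                metrics=dict(type='WordMetric', mode='exact'),
                dataset_prefixes=['IC13', 'IC15'])

        If ``WordMetric`` reports ``word_acc`` under the metric prefix
        ``recog``, the evaluated keys would look like
        ``IC13/recog/word_acc`` and ``IC15/recog/word_acc``.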
""" | |

    def __init__(self, metrics: Union[ConfigType, BaseMetric, Sequence],
                 dataset_prefixes: Sequence[str]) -> None:
        super().__init__(metrics)
        self.dataset_prefixes = dataset_prefixes

    def evaluate(self, size: int) -> dict:
        """Invoke ``evaluate`` method of each metric and collect the metrics
        dictionary.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation results of all metrics. The keys are the names
            of the metrics, and the values are corresponding results.
        """
        metrics_results = OrderedDict()
        dataset_slices = self.dataset_meta.get('cumulative_sizes', [size])
        assert len(dataset_slices) == len(self.dataset_prefixes)
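        # ``cumulative_sizes`` is expected to be recorded in the dataset meta
        # by the ``ConcatDataset`` wrapper and holds the end index of each
        # dataset slice in the concatenated results, e.g. [500, 1200] for two
        # datasets of 500 and 700 samples (illustrative numbers). When it is
        # absent, the whole result list is treated as a single slice of
        # length ``size``.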
        for metric in self.metrics:
            if len(metric.results) == 0:
                warnings.warn(
                    f'{metric.__class__.__name__} got empty `self.results`. '
                    'Please ensure that the processed results are properly '
                    'added into `self.results` in `process` method.')

            results = collect_results(metric.results, size,
                                      metric.collect_device)
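            # ``collect_results`` gathers the result lists from all ranks and
            # truncates the gathered list to ``size``, dropping any samples
            # the dataloader padded to keep rank workloads equal.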
            if is_main_process():
                # cast all tensors in results list to cpu
                results = _to_cpu(results)
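                # Metrics are computed on the main process only; the other
                # ranks receive the final dictionary through
                # ``broadcast_object_list`` at the end of this method. Each
                # (start, end) pair below selects the results belonging to
                # one dataset slice.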
                for start, end, dataset_prefix in zip([0] +
                                                      dataset_slices[:-1],
                                                      dataset_slices,
                                                      self.dataset_prefixes):
                    metric_results = metric.compute_metrics(
                        results[start:end])  # type: ignore
                    # Add prefix to metric names
                    if metric.prefix:
                        final_prefix = '/'.join(
                            (dataset_prefix, metric.prefix))
                    else:
                        final_prefix = dataset_prefix
                    metric_results = {
                        '/'.join((final_prefix, k)): v
                        for k, v in metric_results.items()
                    }

                    # Check metric name conflicts
                    for name in metric_results.keys():
                        if name in metrics_results:
                            raise ValueError(
                                'There are multiple evaluation results with '
                                f'the same metric name {name}. Please make '
                                'sure all metrics have different prefixes.')
                    metrics_results.update(metric_results)
            metric.results.clear()
        if is_main_process():
            metrics_results = [metrics_results]
        else:
            metrics_results = [None]  # type: ignore
        broadcast_object_list(metrics_results)

        return metrics_results[0]