# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher
from typing import Dict, Optional, Sequence, Union

import mmengine
from mmengine.evaluator import BaseMetric
from rapidfuzz.distance import Levenshtein

from mmocr.registry import METRICS

@METRICS.register_module()
class WordMetric(BaseMetric):
    """Word metrics for text recognition task.

    Args:
        mode (str or list[str]): Options are:

            - 'exact': Accuracy at word level.
            - 'ignore_case': Accuracy at word level, ignoring letter
              case.
            - 'ignore_case_symbol': Accuracy at word level, ignoring
              letter case and symbol. (Default metric for academic evaluation)

            If mode is a list, then metrics in mode will be calculated
            separately. Defaults to 'ignore_case_symbol'.
        valid_symbol (str): Valid characters. Defaults to
            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = 'recog'

    def __init__(self,
                 mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)
        self.valid_symbol = re.compile(valid_symbol)
        if isinstance(mode, str):
            mode = [mode]
        assert mmengine.is_seq_of(mode, str)
        assert set(mode).issubset(
            {'exact', 'ignore_case', 'ignore_case_symbol'})
        self.mode = set(mode)

    def process(self, data_batch: Sequence[Dict],
                data_samples: Sequence[Dict]) -> None:
        """Process one batch of data_samples. The processed results should be
        stored in ``self.results``, which will be used to compute the metrics
        when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of gts.
            data_samples (Sequence[Dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            match_num = 0
            match_ignore_case_num = 0
            match_ignore_case_symbol_num = 0
            pred_text = data_sample.get('pred_text').get('item')
            gt_text = data_sample.get('gt_text').get('item')
            if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode:
                pred_text_lower = pred_text.lower()
                gt_text_lower = gt_text.lower()
            if 'ignore_case_symbol' in self.mode:
                gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
                pred_text_lower_ignore = self.valid_symbol.sub(
                    '', pred_text_lower)
                match_ignore_case_symbol_num = \
                    gt_text_lower_ignore == pred_text_lower_ignore
            if 'ignore_case' in self.mode:
                match_ignore_case_num = pred_text_lower == gt_text_lower
            if 'exact' in self.mode:
                match_num = pred_text == gt_text
            result = dict(
                match_num=match_num,
                match_ignore_case_num=match_ignore_case_num,
                match_ignore_case_symbol_num=match_ignore_case_symbol_num)
            self.results.append(result)

    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[Dict]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        eps = 1e-8
        eval_res = {}
        gt_word_num = len(results)
        if 'exact' in self.mode:
            match_nums = [result['match_num'] for result in results]
            match_nums = sum(match_nums)
            eval_res['word_acc'] = 1.0 * match_nums / (eps + gt_word_num)
        if 'ignore_case' in self.mode:
            match_ignore_case_num = [
                result['match_ignore_case_num'] for result in results
            ]
            match_ignore_case_num = sum(match_ignore_case_num)
            eval_res['word_acc_ignore_case'] = 1.0 * \
                match_ignore_case_num / (eps + gt_word_num)
        if 'ignore_case_symbol' in self.mode:
            match_ignore_case_symbol_num = [
                result['match_ignore_case_symbol_num'] for result in results
            ]
            match_ignore_case_symbol_num = sum(match_ignore_case_symbol_num)
            eval_res['word_acc_ignore_case_symbol'] = 1.0 * \
                match_ignore_case_symbol_num / (eps + gt_word_num)
        for key, value in eval_res.items():
            eval_res[key] = float(f'{value:.4f}')
        return eval_res
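

# A minimal usage sketch (not part of the library): the sample dict below is
# a hypothetical stand-in for what MMOCR passes to ``process``, and it shows
# how the three modes differ on a single prediction/ground-truth pair.
#
#     metric = WordMetric(mode=['exact', 'ignore_case', 'ignore_case_symbol'])
#     metric.process(None, [
#         dict(pred_text=dict(item='Hello!'), gt_text=dict(item='hello'))
#     ])
#     print(metric.compute_metrics(metric.results))
#     # {'word_acc': 0.0, 'word_acc_ignore_case': 0.0,
#     #  'word_acc_ignore_case_symbol': 1.0}
#
# 'exact' fails on the capital 'H' and the '!', 'ignore_case' still fails on
# the '!', and 'ignore_case_symbol' strips the '!' before comparing.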


@METRICS.register_module()
class CharMetric(BaseMetric):
    """Character metrics for text recognition task.

    Args:
        valid_symbol (str): Valid characters. Defaults to
            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = 'recog'

    def __init__(self,
                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)
        self.valid_symbol = re.compile(valid_symbol)

    def process(self, data_batch: Sequence[Dict],
                data_samples: Sequence[Dict]) -> None:
        """Process one batch of data_samples. The processed results should be
        stored in ``self.results``, which will be used to compute the metrics
        when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of gts.
            data_samples (Sequence[Dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_text = data_sample.get('pred_text').get('item')
            gt_text = data_sample.get('gt_text').get('item')
            gt_text_lower = gt_text.lower()
            pred_text_lower = pred_text.lower()
            gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
            pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
            # numbers to calculate char-level recall & precision
            result = dict(
                gt_char_num=len(gt_text_lower_ignore),
                pred_char_num=len(pred_text_lower_ignore),
                true_positive_char_num=self._cal_true_positive_char(
                    pred_text_lower_ignore, gt_text_lower_ignore))
            self.results.append(result)

    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[Dict]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        gt_char_num = [result['gt_char_num'] for result in results]
        pred_char_num = [result['pred_char_num'] for result in results]
        true_positive_char_num = [
            result['true_positive_char_num'] for result in results
        ]
        gt_char_num = sum(gt_char_num)
        pred_char_num = sum(pred_char_num)
        true_positive_char_num = sum(true_positive_char_num)

        eps = 1e-8
        char_recall = 1.0 * true_positive_char_num / (eps + gt_char_num)
        char_precision = 1.0 * true_positive_char_num / (eps + pred_char_num)
        eval_res = {}
        eval_res['char_recall'] = char_recall
        eval_res['char_precision'] = char_precision
        for key, value in eval_res.items():
            eval_res[key] = float(f'{value:.4f}')
        return eval_res

    def _cal_true_positive_char(self, pred: str, gt: str) -> int:
        """Calculate correct character number in prediction.

        Args:
            pred (str): Prediction text.
            gt (str): Ground truth text.

        Returns:
            true_positive_char_num (int): The true positive number.
        """
        all_opt = SequenceMatcher(None, pred, gt)
        true_positive_char_num = 0
        for opt, _, _, s2, e2 in all_opt.get_opcodes():
            if opt == 'equal':
                true_positive_char_num += (e2 - s2)
        return true_positive_char_num
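

# A minimal sketch (not library code) of how the SequenceMatcher-based true
# positive count behaves: only characters inside 'equal' opcode spans count,
# so substituted or inserted characters contribute nothing. The sample pair
# below is hypothetical.
#
#     metric = CharMetric()
#     metric.process(None, [
#         dict(pred_text=dict(item='helio'), gt_text=dict(item='hello'))
#     ])
#     # 'helio' vs 'hello' share the equal spans 'hel' and 'o', so the true
#     # positive count is 4 against 5 gt chars and 5 pred chars.
#     print(metric.compute_metrics(metric.results))
#     # {'char_recall': 0.8, 'char_precision': 0.8}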


@METRICS.register_module()
class OneMinusNEDMetric(BaseMetric):
    """One minus NED metric for text recognition task.

    Args:
        valid_symbol (str): Valid characters. Defaults to
            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = 'recog'

    def __init__(self,
                 valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)
        self.valid_symbol = re.compile(valid_symbol)

    def process(self, data_batch: Sequence[Dict],
                data_samples: Sequence[Dict]) -> None:
        """Process one batch of data_samples. The processed results should be
        stored in ``self.results``, which will be used to compute the metrics
        when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of gts.
            data_samples (Sequence[Dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_text = data_sample.get('pred_text').get('item')
            gt_text = data_sample.get('gt_text').get('item')
            gt_text_lower = gt_text.lower()
            pred_text_lower = pred_text.lower()
            gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower)
            pred_text_lower_ignore = self.valid_symbol.sub('', pred_text_lower)
            norm_ed = Levenshtein.normalized_distance(pred_text_lower_ignore,
                                                      gt_text_lower_ignore)
            result = dict(norm_ed=norm_ed)
            self.results.append(result)

    def compute_metrics(self, results: Sequence[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[Dict]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        gt_word_num = len(results)
        norm_ed = [result['norm_ed'] for result in results]
        norm_ed_sum = sum(norm_ed)
        normalized_edit_distance = norm_ed_sum / max(1, gt_word_num)
        eval_res = {}
        eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance
        for key, value in eval_res.items():
            eval_res[key] = float(f'{value:.4f}')
        return eval_res
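

# A minimal sketch (not library code): rapidfuzz's
# ``Levenshtein.normalized_distance`` divides the edit distance by the length
# of the longer string, so one wrong character in a five-character word costs
# 0.2. The sample pair below is hypothetical.
#
#     metric = OneMinusNEDMetric()
#     metric.process(None, [
#         dict(pred_text=dict(item='helio'), gt_text=dict(item='hello'))
#     ])
#     print(metric.compute_metrics(metric.results))
#     # {'1-N.E.D': 0.8}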