Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Dict, Optional, Sequence, Tuple, Union | |
import mmengine | |
import torch | |
from mmengine.structures import LabelData | |
from mmocr.models.common.dictionary import Dictionary | |
from mmocr.registry import TASK_UTILS | |
from mmocr.structures import TextRecogDataSample | |
class BaseTextRecogPostprocessor:
    """Base class of text recognition postprocessors.

    Subclasses implement :meth:`get_single_prediction`; calling the
    postprocessor then decodes every sample of a batch and attaches the
    result to the matching data sample as ``pred_text``.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary`
            or the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        ignore_chars (list[str]): A list of characters to be ignored from
            the final results. Apart from single characters, each item can
            be one of the following reserved keywords: 'padding', 'end' and
            'unknown', which refer to their corresponding special tokens in
            the dictionary. Items that cannot be resolved to an index only
            trigger a ``UserWarning`` and are skipped. Defaults to
            ``['padding']``.
        **kwargs: Accepted but not used by this base class.

    Raises:
        TypeError: If ``dictionary`` is neither a dict nor a `Dictionary`,
            or if ``ignore_chars`` is not a list of str.
    """

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 max_seq_len: int = 40,
                 ignore_chars: Sequence[str] = ['padding'],
                 **kwargs) -> None:
        # A config dict is turned into an object through the registry; any
        # other non-`Dictionary` input is rejected.
        if isinstance(dictionary, dict):
            dictionary = TASK_UTILS.build(dictionary)
        elif not isinstance(dictionary, Dictionary):
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        self.dictionary = dictionary
        self.max_seq_len = max_seq_len

        # Reserved keywords and the special-token index each one stands for.
        keyword_to_index = dict(
            padding=self.dictionary.padding_idx,
            end=self.dictionary.end_idx,
            unknown=self.dictionary.unknown_idx,
        )

        # NOTE: the check requires a real ``list`` (the default above is
        # only read, never mutated).
        if not mmengine.is_list_of(ignore_chars, str):
            raise TypeError('ignore_chars must be list of str')

        resolved_indexes = []
        for token in ignore_chars:
            # Reserved keywords take precedence over a plain character
            # lookup in the dictionary.
            char_index = self.dictionary.char2idx(token, strict=False)
            index = keyword_to_index.get(token, char_index)

            # An index is usable when it exists and does not merely alias
            # the unknown token (unless 'unknown' was requested explicitly).
            usable = index is not None and (
                index != self.dictionary.unknown_idx or token == 'unknown')
            if usable:
                resolved_indexes.append(index)
            else:
                warnings.warn(f'{token} does not exist in the dictionary',
                              UserWarning)
        self.ignore_indexes = resolved_indexes

    def get_single_prediction(
        self,
        probs: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
    ) -> Tuple[Sequence[int], Sequence[float]]:
        """Convert the output probabilities of a single image to index and
        score.

        Must be overridden by subclasses; the base implementation always
        raises.

        Args:
            probs (torch.Tensor): Character probabilities with shape
                :math:`(T, C)`.
            data_sample (TextRecogDataSample): Datasample of an image.

        Returns:
            tuple(list[int], list[float]): Index and scores per-character.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __call__(
        self, probs: torch.Tensor, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Convert outputs to strings and scores.

        Args:
            probs (torch.Tensor): Batched character probabilities, the
                model's softmaxed output in size: :math:`(N, T, C)`.
            data_samples (list[TextRecogDataSample]): The list of
                TextRecogDataSample.

        Returns:
            list(TextRecogDataSample): The same ``data_samples`` sequence,
            where each of the first :math:`N` entries has received a
            ``pred_text`` holding the decoded text (``item``) and the
            per-character scores (``score``).
        """
        for sample_idx in range(probs.size(0)):
            sample_probs = probs[sample_idx, :, :]
            sample = data_samples[sample_idx]
            char_indexes, char_scores = self.get_single_prediction(
                sample_probs, sample)
            decoded = self.dictionary.idx2str(char_indexes)

            prediction = LabelData()
            prediction.score = char_scores
            prediction.item = decoded
            sample.pred_text = prediction
        return data_samples