# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Optional, Sequence, Tuple, Union

import mmengine
import torch
from mmengine.structures import LabelData

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import TASK_UTILS
from mmocr.structures import TextRecogDataSample


class BaseTextRecogPostprocessor:
    """Base text recognition postprocessor.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated by a decoder. Defaults to 40.
        ignore_chars (list[str]): A list of characters to be ignored from the
            final results. The postprocessor will skip over these characters
            when converting raw indexes to characters. Apart from single
            characters, each item can be one of the following reserved
            keywords: 'padding', 'end' and 'unknown', which refer to their
            corresponding special tokens in the dictionary.
    """

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 max_seq_len: int = 40,
                 ignore_chars: Sequence[str] = ['padding'],
                 **kwargs) -> None:
        if isinstance(dictionary, dict):
            self.dictionary = TASK_UTILS.build(dictionary)
        elif isinstance(dictionary, Dictionary):
            self.dictionary = dictionary
        else:
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        self.max_seq_len = max_seq_len

        mapping_table = {
            'padding': self.dictionary.padding_idx,
            'end': self.dictionary.end_idx,
            'unknown': self.dictionary.unknown_idx,
        }
        if not mmengine.is_list_of(ignore_chars, str):
            raise TypeError('ignore_chars must be list of str')
        ignore_indexes = list()
        for ignore_char in ignore_chars:
            index = mapping_table.get(
                ignore_char,
                self.dictionary.char2idx(ignore_char, strict=False))
            if index is None or (index == self.dictionary.unknown_idx
                                 and ignore_char != 'unknown'):
                warnings.warn(
                    f'{ignore_char} does not exist in the dictionary',
                    UserWarning)
                continue
            ignore_indexes.append(index)
        self.ignore_indexes = ignore_indexes
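        # With the default ``ignore_chars=['padding']`` and a dictionary
        # built with ``with_padding=True``, ``ignore_indexes`` holds only
        # ``padding_idx``; subclasses consult it in ``get_single_prediction``
        # to skip these indexes when decoding.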

    def get_single_prediction(
        self,
        probs: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
    ) -> Tuple[Sequence[int], Sequence[float]]:
        """Convert the output probabilities of a single image to indexes and
        scores.

        Args:
            probs (torch.Tensor): Character probabilities with shape
                :math:`(T, C)`.
            data_sample (TextRecogDataSample): Datasample of an image.

        Returns:
            tuple(list[int], list[float]): Indexes and scores per character.
        """
        raise NotImplementedError
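
    # A concrete subclass (e.g. a greedy argmax decoder) overrides
    # ``get_single_prediction``; an illustrative sketch of such an override,
    # together with example usage, is appended at the end of this file.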

    def __call__(
        self, probs: torch.Tensor, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Convert outputs to strings and scores.

        Args:
            probs (torch.Tensor): Batched character probabilities, the
                model's softmaxed output with shape :math:`(N, T, C)`.
            data_samples (list[TextRecogDataSample]): The list of
                TextRecogDataSample.

        Returns:
            list(TextRecogDataSample): The list of TextRecogDataSample. It
            usually contains ``pred_text`` information.
        """
        batch_size = probs.size(0)

        for idx in range(batch_size):
            index, score = self.get_single_prediction(probs[idx, :, :],
                                                      data_samples[idx])
            text = self.dictionary.idx2str(index)
            pred_text = LabelData()
            pred_text.score = score
            pred_text.item = text
            data_samples[idx].pred_text = pred_text
        return data_samples
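

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal greedy
# subclass plus a usage example, assuming a standard MMOCR installation. The
# names ``GreedyTextRecogPostprocessor`` and the temporary dictionary file are
# hypothetical; they only show how ``__call__`` turns batched probabilities of
# shape (N, T, C) into ``pred_text`` fields on the data samples.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import tempfile

    class GreedyTextRecogPostprocessor(BaseTextRecogPostprocessor):
        """Hypothetical subclass: greedy argmax decoding per time step."""

        def get_single_prediction(self, probs, data_sample=None):
            # Pick the most likely character at every time step.
            max_value, max_idx = torch.max(probs, dim=-1)
            indexes, scores = [], []
            for value, idx in zip(max_value, max_idx):
                idx = int(idx)
                # Stop at the end token and drop ignored indexes (padding).
                if idx == self.dictionary.end_idx:
                    break
                if idx in self.ignore_indexes:
                    continue
                indexes.append(idx)
                scores.append(float(value))
            return indexes, scores

    # Write a tiny character dictionary: one character per line.
    with tempfile.NamedTemporaryFile(
            'w', suffix='.txt', delete=False) as dict_file:
        dict_file.write('\n'.join('abc'))

    dictionary = Dictionary(
        dict_file=dict_file.name, with_padding=True, with_end=True)
    postprocessor = GreedyTextRecogPostprocessor(
        dictionary=dictionary, max_seq_len=10)

    # Fake "softmaxed" network output: batch of 2 images, 5 time steps.
    probs = torch.softmax(
        torch.randn(2, 5, dictionary.num_classes), dim=-1)
    data_samples = [TextRecogDataSample() for _ in range(2)]
    results = postprocessor(probs, data_samples)
    print([sample.pred_text.item for sample in results])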