File size: 4,292 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Optional, Sequence, Tuple, Union

import mmengine
import torch
from mmengine.structures import LabelData

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import TASK_UTILS
from mmocr.structures import TextRecogDataSample


class BaseTextRecogPostprocessor:
    """Base text recognition postprocessor.

    Converts the raw per-character probability output of a recognition
    decoder into strings and per-character confidence scores.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        ignore_chars (list[str]): A list of characters to be ignored from the
            final results. Postprocessor will skip over these characters when
            converting raw indexes to characters. Apart from single characters,
            each item can be one of the following reserved keywords: 'padding',
            'end' and 'unknown', which refer to their corresponding special
            tokens in the dictionary. Defaults to ``['padding']``.
    """

    def __init__(self,
                 dictionary: Union[Dictionary, Dict],
                 max_seq_len: int = 40,
                 ignore_chars: Optional[Sequence[str]] = None,
                 **kwargs) -> None:
        # Avoid a mutable default argument (bugbear B006); ``None`` stands in
        # for the documented default ``['padding']``.
        if ignore_chars is None:
            ignore_chars = ['padding']

        # Accept either a ready-built Dictionary or its registry config.
        if isinstance(dictionary, dict):
            self.dictionary = TASK_UTILS.build(dictionary)
        elif isinstance(dictionary, Dictionary):
            self.dictionary = dictionary
        else:
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        self.max_seq_len = max_seq_len

        # Reserved keywords resolve directly to the dictionary's special
        # token indexes (each may be None if the dictionary lacks the token).
        mapping_table = {
            'padding': self.dictionary.padding_idx,
            'end': self.dictionary.end_idx,
            'unknown': self.dictionary.unknown_idx,
        }
        if not mmengine.is_list_of(ignore_chars, str):
            raise TypeError('ignore_chars must be list of str')
        ignore_indexes = []
        for ignore_char in ignore_chars:
            if ignore_char in mapping_table:
                index = mapping_table[ignore_char]
            else:
                # Only look up non-reserved characters; ``dict.get`` with a
                # call as its default would evaluate char2idx eagerly even
                # when the keyword is already in mapping_table.
                index = self.dictionary.char2idx(ignore_char, strict=False)
            # char2idx(strict=False) maps unseen characters to unknown_idx,
            # so an unknown_idx result (unless the user literally asked for
            # 'unknown') also means "not found in the dictionary".
            if index is None or (index == self.dictionary.unknown_idx
                                 and ignore_char != 'unknown'):
                warnings.warn(
                    f'{ignore_char} does not exist in the dictionary',
                    UserWarning)
                continue
            ignore_indexes.append(index)
        self.ignore_indexes = ignore_indexes

    def get_single_prediction(
        self,
        probs: torch.Tensor,
        data_sample: Optional[TextRecogDataSample] = None,
    ) -> Tuple[Sequence[int], Sequence[float]]:
        """Convert the output probabilities of a single image to index and
        score.

        Args:
           probs (torch.Tensor): Character probabilities with shape
                :math:`(T, C)`.
           data_sample (TextRecogDataSample): Datasample of an image.

        Returns:
            tuple(list[int], list[float]): Index and scores per-character.
        """
        # Subclasses implement the actual decoding strategy (e.g. greedy
        # argmax with CTC de-duplication).
        raise NotImplementedError

    def __call__(
        self, probs: torch.Tensor, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Convert outputs to strings and scores.

        Args:
            probs (torch.Tensor): Batched character probabilities, the model's
                softmaxed output in size: :math:`(N, T, C)`.
            data_samples (list[TextRecogDataSample]): The list of
                TextRecogDataSample.

        Returns:
            list(TextRecogDataSample): The list of TextRecogDataSample. It
            usually contain ``pred_text`` information.
        """
        batch_size = probs.size(0)

        # Decode each sample independently and attach the result in place.
        for idx in range(batch_size):
            index, score = self.get_single_prediction(probs[idx, :, :],
                                                      data_samples[idx])
            text = self.dictionary.idx2str(index)
            pred_text = LabelData()
            pred_text.score = score
            pred_text.item = text
            # NOTE: mutates the input data_samples rather than copying them.
            data_samples[idx].pred_text = pred_text
        return data_samples