# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import TASK_UTILS


@TASK_UTILS.register_module()
class SPTSDictionary(Dictionary):
    """The class generates a dictionary for recognition. It pre-defines four
    special tokens: ``start_token``, ``end_token``, ``pad_token``, and
    ``unknown_token``, which will be sequentially placed at the end of the
    dictionary when their corresponding flags are True.

    Args:
        dict_file (str): The path of the character dict file, in which each
            line contains a single character.
        num_bins (int): Number of bins dividing the image, which is used to
            shift the character indexes. Defaults to 1000.
        with_start (bool): The flag to control whether to include the start
            token. Defaults to False.
        with_end (bool): The flag to control whether to include the end token.
            Defaults to False.
        with_seq_end (bool): The flag to control whether to include the
            sequence end token. Defaults to False.
        same_start_end (bool): The flag to control whether the start token and
            end token are the same. It only works when both ``with_start`` and
            ``with_end`` are True. Defaults to False.
        with_padding (bool): The flag to control whether to include the
            padding token. The padding token may represent more than padding;
            it can also represent tokens like the blank token in CTC or the
            background token in SegOCR. Defaults to False.
        with_unknown (bool): The flag to control whether to include the
            unknown token. Defaults to False.
        start_token (str): The start token as a string. Defaults to '<BOS>'.
        end_token (str): The end token as a string. Defaults to '<EOS>'.
        seq_end_token (str): The sequence end token as a string. Defaults to
            '<SEQEOS>'.
        start_end_token (str): The start/end token as a string, used when the
            start and end tokens are the same. Defaults to '<BOS/EOS>'.
        padding_token (str): The padding token as a string.
            Defaults to '<PAD>'.
        unknown_token (str, optional): The unknown token as a string. If it's
            set to None and ``with_unknown`` is True, the unknown token will be
            skipped when converting string to index. Defaults to '<UKN>'.
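
    Examples:
        >>> # A minimal sketch; ``dicts/example.txt`` is a hypothetical dict
        >>> # file whose two lines are 'a' and 'b'.
        >>> dictionary = SPTSDictionary(
        ...     dict_file='dicts/example.txt',
        ...     num_bins=1000,
        ...     with_start=True,
        ...     with_end=True,
        ...     same_start_end=True,
        ...     with_seq_end=True,
        ...     with_padding=True,
        ...     with_unknown=True)
        >>> # 1000 coordinate bins + 2 characters + 4 special tokens
        >>> dictionary.num_classes
        1006
        >>> # character indexes are shifted past the coordinate bins
        >>> dictionary.char2idx('a')
        1000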
    """

    def __init__(
        self,
        dict_file: str,
        num_bins: int = 1000,
        with_start: bool = False,
        with_end: bool = False,
        with_seq_end: bool = False,
        same_start_end: bool = False,
        with_padding: bool = False,
        with_unknown: bool = False,
        start_token: str = '<BOS>',
        end_token: str = '<EOS>',
        seq_end_token: str = '<SEQEOS>',
        start_end_token: str = '<BOS/EOS>',
        padding_token: str = '<PAD>',
        unknown_token: str = '<UKN>',
    ) -> None:
        self.with_seq_end = with_seq_end
        self.seq_end_token = seq_end_token

        super().__init__(
            dict_file=dict_file,
            with_start=with_start,
            with_end=with_end,
            same_start_end=same_start_end,
            with_padding=with_padding,
            with_unknown=with_unknown,
            start_token=start_token,
            end_token=end_token,
            start_end_token=start_end_token,
            padding_token=padding_token,
            unknown_token=unknown_token)

        self.num_bins = num_bins
        self._shift_idx()

    @property
    def num_classes(self) -> int:
        """int: Number of output classes. Special tokens are counted.
        """
        return len(self._dict) + self.num_bins

    def _shift_idx(self):
        """Shift all character and special token indexes by ``num_bins`` so
        that the first ``num_bins`` indexes are reserved for coordinate
        bins."""
        idx_terms = [
            'start_idx', 'end_idx', 'unknown_idx', 'seq_end_idx', 'padding_idx'
        ]
        for term in idx_terms:
            value = getattr(self, term)
            # A term is None when its token is disabled; checking
            # ``is not None`` (rather than truthiness) also shifts a
            # valid index of 0.
            if value is not None:
                setattr(self, term, value + self.num_bins)
        for char in self._dict:
            self._char2idx[char] += self.num_bins

    def _update_dict(self):
        """Update the dict with tokens according to parameters."""
        # BOS/EOS
        self.start_idx = None
        self.end_idx = None
        # sequence end; initialized here because ``_shift_idx`` reads it
        # unconditionally
        self.seq_end_idx = None
        # unknown
        self.unknown_idx = None
        # Skip the unknown token when it is None, as documented in the class
        # docstring; this mirrors the guarded check in the parent
        # ``Dictionary``.
        if self.with_unknown and self.unknown_token is not None:
            self._dict.append(self.unknown_token)
            self.unknown_idx = len(self._dict) - 1

        if self.with_start and self.with_end and self.same_start_end:
            self._dict.append(self.start_end_token)
            self.start_idx = len(self._dict) - 1
            self.end_idx = self.start_idx
            if self.with_seq_end:
                self._dict.append(self.seq_end_token)
                self.seq_end_idx = len(self._dict) - 1
        else:
            if self.with_end:
                self._dict.append(self.end_token)
                self.end_idx = len(self._dict) - 1
            if self.with_seq_end:
                self._dict.append(self.seq_end_token)
                self.seq_end_idx = len(self._dict) - 1
            if self.with_start:
                self._dict.append(self.start_token)
                self.start_idx = len(self._dict) - 1

        # padding
        self.padding_idx = None
        if self.with_padding:
            self._dict.append(self.padding_token)
            self.padding_idx = len(self._dict) - 1

        # update char2idx
        self._char2idx = {}
        for idx, char in enumerate(self._dict):
            self._char2idx[char] = idx

    def idx2str(self, index: Sequence[int]) -> str:
        """Convert a list of index to string.

        Args:
            index (list[int]): The list of indexes to convert to string.

        Return:
            str: The converted string.
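
        Examples:
            >>> # Continuing the hypothetical dictionary from the class
            >>> # docstring; shifted indexes map back into the dict.
            >>> dictionary.idx2str([1000, 1001])
            'ab'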
        """
        assert isinstance(index, (list, tuple))
        string = ''
        for i in index:
            assert i < self.num_classes, (
                f'Index: {i} out of range! Index must be less than '
                f'{self.num_classes}')
            # TODO: find its difference from ignore_chars
            # in TextRecogPostprocessor
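            # Indexes below ``num_bins`` denote coordinate bins; this method
            # assumes character indexes only, as a negative ``shifted_i``
            # would wrap around ``self._dict``.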
            shifted_i = i - self.num_bins
            if self._dict[shifted_i] is not None:
                string += self._dict[shifted_i]
        return string