# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import TASK_UTILS
from mmocr.structures import TextRecogDataSample


class BaseTextRecogModuleLoss(nn.Module):
    """Base recognition loss.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config to construct a
            `Dictionary` instance, or the instance itself.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated by the decoder. Defaults to 40.
        letter_case (str): There are three options for altering the letter
            case of gt texts:
            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.
            This usually only affects English characters. Defaults to
            'unchanged'.
        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
            Defaults to 'auto'. Options are:
            - 'auto': Use dictionary.padding_idx to pad gt texts, or
              dictionary.end_idx if dictionary.padding_idx
              is None.
            - 'padding': Always use dictionary.padding_idx to pad gt texts.
            - 'end': Always use dictionary.end_idx to pad gt texts.
            - 'none': Do not pad gt texts.
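
    Example:
        A minimal construction sketch; the ``dict_file`` path below is a
        placeholder, not a real file:

        >>> module_loss = BaseTextRecogModuleLoss(
        ...     dictionary=dict(
        ...         type='Dictionary',
        ...         dict_file='path/to/dict_file.txt',
        ...         with_start=True,
        ...         with_end=True,
        ...         with_padding=True,
        ...         with_unknown=True),
        ...     letter_case='lower',
        ...     pad_with='auto')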
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
                 pad_with: str = 'auto',
                 **kwargs) -> None:
        super().__init__()
        if isinstance(dictionary, dict):
            self.dictionary = TASK_UTILS.build(dictionary)
        elif isinstance(dictionary, Dictionary):
            self.dictionary = dictionary
        else:
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        self.max_seq_len = max_seq_len
        assert letter_case in ['unchanged', 'upper', 'lower']
        self.letter_case = letter_case

        assert pad_with in ['auto', 'padding', 'end', 'none']
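        # Resolve the index used to pad ``gt_text.padded_indexes``.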
        if pad_with == 'auto':
            # ``padding_idx`` may be 0, so compare against None explicitly.
            self.pad_idx = self.dictionary.padding_idx \
                if self.dictionary.padding_idx is not None \
                else self.dictionary.end_idx
        elif pad_with == 'padding':
            self.pad_idx = self.dictionary.padding_idx
        elif pad_with == 'end':
            self.pad_idx = self.dictionary.end_idx
        else:
            self.pad_idx = None
        if self.pad_idx is None and pad_with != 'none':
            if pad_with == 'auto':
                raise ValueError('pad_with="auto", but dictionary.end_idx'
                                 ' and dictionary.padding_idx are both None')
            else:
                raise ValueError(
                    f'pad_with="{pad_with}", but dictionary.{pad_with}_idx is'
                    ' None')

    def get_targets(
        self, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Target generator.

        Args:
            data_samples (list[TextRecogDataSample]): It usually includes
                ``gt_text`` information.

        Returns:
            list[TextRecogDataSample]: Updated data_samples. Two keys will be
            added to each data_sample's ``gt_text``:

            - indexes (torch.LongTensor): Character indexes representing gt
              texts. All special tokens are excluded, except for UKN.
            - padded_indexes (torch.LongTensor): Character indexes
              representing gt texts, wrapped with BOS and EOS if applicable
              and followed by padding indexes until the length reaches
              ``max_seq_len``. In particular, if ``pad_with='none'``, no
              padding is applied.
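
        Example:
            A rough illustration, assuming a dictionary where
            ``str2idx('ab') == [0, 1]``, ``start_idx == 2``,
            ``end_idx == 3``, ``padding_idx == 4``, and
            ``max_seq_len == 6``::

                indexes        = [0, 1]              # 'a', 'b'
                padded_indexes = [2, 0, 1, 3, 4, 4]  # BOS a b EOS PAD PAD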
        """

        for data_sample in data_samples:
            if data_sample.get('have_target', False):
                continue
            text = data_sample.gt_text.item
            if self.letter_case in ['upper', 'lower']:
                text = getattr(text, self.letter_case)()
            indexes = self.dictionary.str2idx(text)
            indexes = torch.LongTensor(indexes)

            # target indexes for loss
            src_target = torch.zeros(indexes.size(0) + 2, dtype=torch.long)
            src_target[1:-1] = indexes
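            # Fill in BOS/EOS only when the dictionary defines them; the
            # unused end slots are sliced away below.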
            if self.dictionary.start_idx is not None:
                src_target[0] = self.dictionary.start_idx
                slice_start = 0
            else:
                slice_start = 1
            if self.dictionary.end_idx is not None:
                src_target[-1] = self.dictionary.end_idx
                slice_end = src_target.size(0)
            else:
                slice_end = src_target.size(0) - 1
            src_target = src_target[slice_start:slice_end]
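            # Pad (or truncate) to ``max_seq_len`` unless pad_with='none'.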
            if self.pad_idx is not None:
                padded_indexes = torch.full((self.max_seq_len, ),
                                            self.pad_idx,
                                            dtype=torch.long)
                char_num = min(src_target.size(0), self.max_seq_len)
                padded_indexes[:char_num] = src_target[:char_num]
            else:
                padded_indexes = src_target
            # put in DataSample
            data_sample.gt_text.indexes = indexes
            data_sample.gt_text.padded_indexes = padded_indexes
            data_sample.set_metainfo(dict(have_target=True))
        return data_samples