File size: 10,207 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
from mmengine.model import ModuleList

from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


@MODELS.register_module()
class NRTRDecoder(BaseDecoder):
    """Transformer Decoder block with self attention mechanism.

    The decoder embeds the target character indexes, adds positional
    encoding, runs a stack of ``TFDecoderLayer`` attending to the encoder
    output, applies a final ``LayerNorm`` and projects to the dictionary
    classes with a linear classifier.

    Args:
        n_layers (int): Number of attention layers. Defaults to 6.
        d_embedding (int): Language embedding dimension. Defaults to 512.
            NOTE(review): the embedded sequence is fed straight into layers
            built with ``d_model``, so this presumably has to equal
            ``d_model`` - confirm against ``TFDecoderLayer``.
        n_head (int): Number of parallel attention heads. Defaults to 8.
        d_k (int): Dimension of the key vector. Defaults to 64.
        d_v (int): Dimension of the value vector. Defaults to 64.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
            Defaults to 512.
        d_inner (int): Hidden dimension of feedforward layers. Defaults to 256.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len``. Defaults to 200.
        dropout (float): Dropout rate for text embedding, MHSA, FFN. Defaults
            to 0.1.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum output sequence length :math:`T`. Defaults
            to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 n_layers: int = 6,
                 d_embedding: int = 512,
                 n_head: int = 8,
                 d_k: int = 64,
                 d_v: int = 64,
                 d_model: int = 512,
                 d_inner: int = 256,
                 n_position: int = 200,
                 dropout: float = 0.1,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 dictionary: Optional[Union[Dict, Dictionary]] = None,
                 max_seq_len: int = 30,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            dictionary=dictionary,
            init_cfg=init_cfg,
            max_seq_len=max_seq_len)

        self.padding_idx = self.dictionary.padding_idx
        self.start_idx = self.dictionary.start_idx
        self.max_seq_len = max_seq_len

        # Target-side text embedding; the padding index gets a fixed vector.
        self.trg_word_emb = nn.Embedding(
            self.dictionary.num_classes,
            d_embedding,
            padding_idx=self.padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        pred_num_class = self.dictionary.num_classes
        self.classifier = nn.Linear(d_model, pred_num_class)
        # Only used at test time to turn logits into probabilities.
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _collect_valid_ratios(
        data_samples: Optional[Sequence[TextRecogDataSample]]
    ) -> List[float]:
        """Gather the ``valid_ratio`` of every data sample.

        Args:
            data_samples (list[TextRecogDataSample], optional): Batch of
                data samples. ``None`` is treated as an empty batch.

        Returns:
            list[float]: One ratio per sample. A sample without a
            ``valid_ratio`` contributes ``1.0`` (nothing is masked). The
            list is empty when ``data_samples`` is None or empty.
        """
        if data_samples is None:
            return []
        ratios = (sample.get('valid_ratio') for sample in data_samples)
        # A missing ratio would otherwise break ``T * valid_ratio`` later.
        return [1.0 if ratio is None else ratio for ratio in ratios]

    def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor:
        """Generate mask for target sequence.

        The mask combines a padding mask (padding positions cannot be
        attended to) with a causal mask (a position cannot attend to later
        positions).

        Args:
            trg_seq (torch.Tensor): Input text sequence. Shape :math:`(N, T)`.

        Returns:
            Tensor: Boolean target mask. Shape :math:`(N, T, T)`.
            E.g.:
            seq = torch.Tensor([[1, 2, 0, 0]]), pad_idx = 0, then
            target_mask =
            torch.Tensor([[[True, False, False, False],
            [True, True, False, False],
            [True, True, False, False],
            [True, True, False, False]]])
        """

        # (N, 1, T): True where the token is not padding.
        pad_mask = (trg_seq != self.padding_idx).unsqueeze(-2)

        len_s = trg_seq.size(1)
        # (1, T, T): lower triangle (diagonal included) is True.
        subsequent_mask = torch.tril(
            torch.ones((len_s, len_s), device=trg_seq.device))
        subsequent_mask = subsequent_mask.unsqueeze(0).bool()

        return pad_mask & subsequent_mask

    def _get_source_mask(self, src_seq: torch.Tensor,
                         valid_ratios: Sequence[float]
                         ) -> Optional[torch.Tensor]:
        """Generate mask for source sequence.

        Args:
            src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
            valid_ratios (list[float]): The valid ratio of input image. For
                example, if the width of the original image is w1 and the width
                after padding is w2, then valid_ratio = w1/w2. Source mask is
                used to cover the area of the padding region.

        Returns:
            Tensor or None: Source mask. Shape :math:`(N, T)`, same dtype and
            device as ``src_seq``. The padding area is 0 and the rest is 1.
            None is returned when ``valid_ratios`` is empty.
        """

        N, T, _ = src_seq.size()
        if len(valid_ratios) == 0:
            return None

        mask = src_seq.new_zeros((N, T))
        for i, valid_ratio in enumerate(valid_ratios):
            # Round up so a partially valid column stays visible, and never
            # exceed the sequence length.
            valid_width = min(T, math.ceil(T * valid_ratio))
            mask[i, :valid_width] = 1

        return mask

    def _attention(self,
                   trg_seq: torch.Tensor,
                   src: torch.Tensor,
                   src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """A wrapped process for transformer based decoder including text
        embedding, position embedding, N x transformer decoder and a LayerNorm
        operation.

        Args:
            trg_seq (Tensor): Target sequence in. Shape :math:`(N, T)`.
            src (Tensor): Source sequence from encoder. Shape
                :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
            src_mask (Tensor, Optional): Mask for source sequence.
                Shape :math:`(N, T)`. Defaults to None.

        Returns:
            Tensor: Output sequence from transformer decoder.
            Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
        """

        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        trg_mask = self._get_target_mask(trg_seq)

        output = self.dropout(trg_pos_encoded)
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)

        return self.layer_norm(output)

    def forward_train(self,
                      feat: Optional[torch.Tensor] = None,
                      out_enc: torch.Tensor = None,
                      data_samples: Sequence[TextRecogDataSample] = None
                      ) -> torch.Tensor:
        """Forward for training. Source mask will be used here.

        Args:
            feat (Tensor, optional): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``. Defaults to None.
            data_samples (list[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text and valid_ratio
                information. Required for training even though it defaults
                to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
            :math:`C` is ``num_classes``.

        Raises:
            TypeError: If ``data_samples`` is None, since the ground-truth
                target sequence is read from it.
        """
        if data_samples is None:
            raise TypeError(
                'data_samples is required by forward_train to provide the '
                'ground-truth text, but got None')

        valid_ratios = self._collect_valid_ratios(data_samples)
        src_mask = self._get_source_mask(out_enc, valid_ratios)

        # Teacher forcing: the padded ground-truth indexes are the input.
        trg_seq = torch.stack([
            data_sample.gt_text.padded_indexes.to(out_enc.device)
            for data_sample in data_samples
        ],
                              dim=0)
        attn_output = self._attention(trg_seq, out_enc, src_mask=src_mask)

        return self.classifier(attn_output)

    def forward_test(self,
                     feat: Optional[torch.Tensor] = None,
                     out_enc: torch.Tensor = None,
                     data_samples: Sequence[TextRecogDataSample] = None
                     ) -> torch.Tensor:
        """Forward for testing.

        Decodes greedily: at each step the most probable class is written
        into the next position of the target sequence.

        Args:
            feat (Tensor, optional): Unused.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
                Defaults to None.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing valid_ratio information.
                When None, no source mask is applied. Defaults to None.

        Returns:
            Tensor: Character probabilities of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        valid_ratios = self._collect_valid_ratios(data_samples)
        src_mask = self._get_source_mask(out_enc, valid_ratios)

        N = out_enc.size(0)
        # bsz * (max_seq_len + 1): start token followed by padding, which
        # is overwritten step by step with the predicted indexes.
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(self.max_seq_len):
            # bsz * seq_len * D_m
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # bsz * num_classes
            step_result = self.classifier(decoder_output[:, step, :])
            outputs.append(step_result)
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index

        outputs = torch.stack(outputs, dim=1)

        return self.softmax(outputs)