File size: 6,209 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


@MODELS.register_module()
class RobustScannerFuser(BaseDecoder):
    """Decoder for RobustScanner.

    It builds a hybrid decoder and a position decoder, concatenates their
    glimpses and passes the result through a linear layer, a GLU and a final
    prediction layer.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        hybrid_decoder (dict): Config to build hybrid_decoder. Defaults to
            dict(type='SequenceAttentionDecoder'). The passed config is not
            modified.
        position_decoder (dict): Config to build position_decoder. Defaults to
            dict(type='PositionAttentionDecoder'). The passed config is not
            modified.
        max_seq_len (int): Maximum sequence length. The
            sequence is usually generated from decoder. Defaults to 30.
        in_channels (list[int]): List of input channels.
            Defaults to [512, 512].
        dim (int): The dimension along which the two glimpses are
            concatenated and on which the GLU splits its input.
            Defaults to -1.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 hybrid_decoder: Dict = dict(type='SequenceAttentionDecoder'),
                 position_decoder: Dict = dict(
                     type='PositionAttentionDecoder'),
                 max_seq_len: int = 30,
                 in_channels: List[int] = [512, 512],
                 dim: int = -1,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len,
            init_cfg=init_cfg)

        # Explicit mapping instead of ``eval`` on the argument names.
        sub_decoder_cfgs = {
            'hybrid_decoder': hybrid_decoder,
            'position_decoder': position_decoder,
        }
        for cfg_name, cfg in sub_decoder_cfgs.items():
            if cfg is None:
                continue
            # Work on a shallow copy: the configs may be the shared default
            # arguments (or a dict owned by the caller), and updating them in
            # place would leak this instance's dictionary / max_seq_len into
            # every later instantiation.
            cfg = cfg.copy()
            if cfg.get('dictionary', None) is None:
                cfg.update(dictionary=self.dictionary)
            else:
                warnings.warn(f"Using dictionary {cfg['dictionary']} "
                              "in decoder's config.")
            if cfg.get('max_seq_len', None) is None:
                cfg.update(max_seq_len=max_seq_len)
            else:
                warnings.warn(f"Using max_seq_len {cfg['max_seq_len']} "
                              "in decoder's config.")
            setattr(self, cfg_name, MODELS.build(cfg))

        # The fusion layers operate on the concatenation of both glimpses.
        fused_channels = sum(in_channels)
        self.dim = dim

        self.linear_layer = nn.Linear(fused_channels, fused_channels)
        self.glu_layer = nn.GLU(dim=dim)
        # GLU halves the size of the dimension it is applied on.
        self.prediction = nn.Linear(fused_channels // 2,
                                    self.dictionary.num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: The output of the prediction layer. Softmax is not
            applied here.
        """
        hybrid_glimpse = self.hybrid_decoder(feat, out_enc, data_samples)
        position_glimpse = self.position_decoder(feat, out_enc, data_samples)
        fusion_input = torch.cat([hybrid_glimpse, position_glimpse], self.dim)
        outputs = self.linear_layer(fusion_input)
        outputs = self.glu_layer(outputs)
        return self.prediction(outputs)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing valid_ratio information.
                Defaults to None.

        Returns:
            Tensor: Character probabilities. of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        # The position branch is computed once for all steps.
        position_glimpse = self.position_decoder(feat, out_enc, data_samples)

        batch_size = feat.size(0)
        # Every position starts as the start index; later positions are
        # overwritten with the predictions made at the preceding step.
        decode_sequence = (feat.new_ones((batch_size, self.max_seq_len)) *
                           self.dictionary.start_idx).long()
        outputs = []
        for step in range(self.max_seq_len):
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, step, data_samples)

            fusion_input = torch.cat(
                [hybrid_glimpse_step, position_glimpse[:, step, :]], self.dim)
            output = self.linear_layer(fusion_input)
            output = self.glu_layer(output)
            output = self.prediction(output)
            _, max_idx = torch.max(output, dim=1, keepdim=False)
            # Feed the most likely class into the next decoding position.
            if step < self.max_seq_len - 1:
                decode_sequence[:, step + 1] = max_idx
            outputs.append(output)
        outputs = torch.stack(outputs, 1)
        return self.softmax(outputs)