File size: 4,096 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch.nn as nn

__all__ = ['build_decoder']


def build_decoder(config):
    """Instantiate a recognition decoder from a config mapping.

    Args:
        config: Mapping with a ``'name'`` key selecting the decoder class;
            every other key is forwarded as a keyword argument to that
            class's constructor. The top-level mapping itself is not
            modified, so the same config can be used to build again.

    Returns:
        The constructed decoder instance.

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        AssertionError: if ``'name'`` is not one of the supported decoders.
    """
    # rec decoder
    # Decoder modules are imported inside the function rather than at module
    # top level. NOTE(review): presumably to avoid import cycles with
    # sub-modules that import this package — confirm before hoisting.
    from .abinet_decoder import ABINetDecoder
    from .aster_decoder import ASTERDecoder
    from .cdistnet_decoder import CDistNetDecoder
    from .cppd_decoder import CPPDDecoder
    from .rctc_decoder import RCTCDecoder
    from .ctc_decoder import CTCDecoder
    from .dan_decoder import DANDecoder
    from .igtr_decoder import IGTRDecoder
    from .lister_decoder import LISTERDecoder
    from .lpv_decoder import LPVDecoder
    from .mgp_decoder import MGPDecoder
    from .nrtr_decoder import NRTRDecoder
    from .parseq_decoder import PARSeqDecoder
    from .robustscanner_decoder import RobustScannerDecoder
    from .sar_decoder import SARDecoder
    from .smtr_decoder import SMTRDecoder
    from .smtr_decoder_nattn import SMTRDecoderNumAttn
    from .srn_decoder import SRNDecoder
    from .visionlan_decoder import VisionLANDecoder
    from .matrn_decoder import MATRNDecoder
    from .cam_decoder import CAMDecoder
    from .ote_decoder import OTEDecoder
    from .bus_decoder import BUSDecoder

    # Explicit name -> class registry instead of ``eval`` on a config string.
    # Insertion order matches the historical support list, so the error
    # message below is unchanged. GTCDecoder is defined in this module.
    registry = {
        'CTCDecoder': CTCDecoder,
        'NRTRDecoder': NRTRDecoder,
        'CPPDDecoder': CPPDDecoder,
        'ABINetDecoder': ABINetDecoder,
        'CDistNetDecoder': CDistNetDecoder,
        'VisionLANDecoder': VisionLANDecoder,
        'PARSeqDecoder': PARSeqDecoder,
        'IGTRDecoder': IGTRDecoder,
        'SMTRDecoder': SMTRDecoder,
        'LPVDecoder': LPVDecoder,
        'SARDecoder': SARDecoder,
        'RobustScannerDecoder': RobustScannerDecoder,
        'SRNDecoder': SRNDecoder,
        'ASTERDecoder': ASTERDecoder,
        'RCTCDecoder': RCTCDecoder,
        'LISTERDecoder': LISTERDecoder,
        'GTCDecoder': GTCDecoder,
        'SMTRDecoderNumAttn': SMTRDecoderNumAttn,
        'MATRNDecoder': MATRNDecoder,
        'MGPDecoder': MGPDecoder,
        'DANDecoder': DANDecoder,
        'CAMDecoder': CAMDecoder,
        'OTEDecoder': OTEDecoder,
        'BUSDecoder': BUSDecoder,
    }
    support_dict = list(registry)

    # Shallow copy: popping 'name' from the caller's dict would make any
    # later rebuild from the same config fail with KeyError.
    config = dict(config)
    module_name = config.pop('name')
    if module_name not in support_dict:
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type callers see is unchanged.
        raise AssertionError('decoder only support {}'.format(support_dict))
    decoder = registry[module_name](**config)
    return decoder


class GTCDecoder(nn.Module):
    """Combine a CTC decoder with an optional second ("GTC") decoder.

    Both sub-decoders are built with :func:`build_decoder` and read the same
    input features ``x``.

    Args:
        in_channels: Passed to both sub-decoders as ``in_channels``.
        gtc_decoder: Config dict (must contain ``'name'``) for the GTC
            branch. Only used when ``infer_gtc`` is true.
        ctc_decoder: Config dict (must contain ``'name'``) for the CTC branch.
        detach: If true, the CTC branch receives ``x.detach()``, so gradients
            from the CTC branch do not flow back into ``x``.
        infer_gtc: If true, the GTC branch is built and is also run outside
            training; ``out_channels`` must then be indexable as
            ``(gtc_out_channels, ctc_out_channels)``. If false, only the CTC
            branch is built and ``out_channels`` goes to it unchanged.
        out_channels: Output size(s) for the sub-decoders; see ``infer_gtc``.
        **kwargs: Accepted and ignored.

    The caller's ``gtc_decoder`` / ``ctc_decoder`` dicts are not modified.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 detach=True,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super(GTCDecoder, self).__init__()
        self.detach = detach
        self.infer_gtc = infer_gtc
        # Shallow-copy the sub-configs before filling in channel sizes so the
        # caller's config is left intact for any later rebuild.
        ctc_decoder = dict(ctc_decoder)
        ctc_decoder['in_channels'] = in_channels
        if infer_gtc:
            gtc_decoder = dict(gtc_decoder)
            gtc_decoder['out_channels'] = out_channels[0]
            ctc_decoder['out_channels'] = out_channels[1]
            gtc_decoder['in_channels'] = in_channels
            # Built before the CTC branch, as originally, so parameter
            # initialisation order and state_dict key order are unchanged.
            self.gtc_decoder = build_decoder(gtc_decoder)
        else:
            ctc_decoder['out_channels'] = out_channels
            # No GTC branch; forward() falls back to CTC-only output.
            self.gtc_decoder = None
        self.ctc_decoder = build_decoder(ctc_decoder)

    def forward(self, x, data=None):
        """Run the CTC branch and, when available and requested, the GTC branch.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` when a GTC branch exists
            and the module is training or ``infer_gtc`` is set; otherwise the
            bare CTC prediction.
        """
        ctc_pred = self.ctc_decoder(x.detach() if self.detach else x,
                                    data=data)
        # getattr: tolerate instances that never set the attribute (e.g.
        # objects pickled before this attribute was always assigned).
        gtc_decoder = getattr(self, 'gtc_decoder', None)
        if gtc_decoder is None or not (self.training or self.infer_gtc):
            return ctc_pred
        # flatten(2) + transpose(1, 2): (B, C, *rest) -> (B, prod(rest), C).
        # NOTE(review): assumes x is a (B, C, H, W)-style feature map — confirm.
        gtc_pred = gtc_decoder(x.flatten(2).transpose(1, 2), data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}


class GTCDecoderTwo(nn.Module):
    """CTC + GTC decoder pair fed by two separate input tensors.

    Unlike ``GTCDecoder``, both branches are always built, and ``forward``
    expects ``x`` to be a pair ``(x_ctc, x_gtc)``.

    Args:
        in_channels: Passed to both sub-decoders as ``in_channels``.
        gtc_decoder: Config dict (must contain ``'name'``) for the GTC branch.
        ctc_decoder: Config dict (must contain ``'name'``) for the CTC branch.
        infer_gtc: If true, the GTC branch also runs outside training.
        out_channels: Indexable as ``(gtc_out_channels, ctc_out_channels)``.
            NOTE(review): the default ``0`` is not indexable, so this argument
            is effectively required.
        **kwargs: Accepted and ignored.

    The caller's ``gtc_decoder`` / ``ctc_decoder`` dicts are not modified.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super(GTCDecoderTwo, self).__init__()
        self.infer_gtc = infer_gtc
        # Shallow copies: channel sizes are filled in below, and editing the
        # caller's dicts in place would leak into any later rebuild.
        gtc_decoder = dict(gtc_decoder)
        ctc_decoder = dict(ctc_decoder)
        gtc_decoder['out_channels'] = out_channels[0]
        ctc_decoder['out_channels'] = out_channels[1]
        gtc_decoder['in_channels'] = in_channels
        ctc_decoder['in_channels'] = in_channels
        self.gtc_decoder = build_decoder(gtc_decoder)
        self.ctc_decoder = build_decoder(ctc_decoder)

    def forward(self, x, data=None):
        """Decode ``x = (x_ctc, x_gtc)``.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` while training or when
            ``infer_gtc`` is set; otherwise only the CTC prediction.
        """
        x_ctc, x_gtc = x
        ctc_pred = self.ctc_decoder(x_ctc, data=data)
        if not (self.training or self.infer_gtc):
            return ctc_pred
        # flatten(2) + transpose(1, 2): (B, C, *rest) -> (B, prod(rest), C).
        # NOTE(review): assumes x_gtc is a (B, C, H, W)-style map — confirm.
        gtc_pred = self.gtc_decoder(x_gtc.flatten(2).transpose(1, 2),
                                    data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}