Spaces:
Running
Running
File size: 4,096 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import torch.nn as nn
__all__ = ['build_decoder']


class _UnsupportedDecoderError(ValueError, AssertionError):
    """Raised by ``build_decoder`` when the config names an unknown decoder.

    Inherits from ``ValueError`` because an unknown name is a bad argument,
    and from ``AssertionError`` so callers written against the earlier
    ``assert``-based check keep catching it.
    """


def build_decoder(config):
    """Instantiate a recognition decoder from a config mapping.

    Args:
        config: mapping whose ``'name'`` entry selects the decoder class;
            every other entry is forwarded to that class as a keyword
            argument. The mapping passed in is not modified.

    Returns:
        The constructed decoder instance.

    Raises:
        KeyError: if ``config`` has no ``'name'`` entry.
        ValueError: (also an ``AssertionError``) if ``'name'`` is not one of
            the supported decoder names.
    """
    # rec decoder
    # NOTE(review): imports are function-local, presumably to avoid import
    # cycles or eager import cost — confirm before hoisting them.
    from .abinet_decoder import ABINetDecoder
    from .aster_decoder import ASTERDecoder
    from .cdistnet_decoder import CDistNetDecoder
    from .cppd_decoder import CPPDDecoder
    from .rctc_decoder import RCTCDecoder
    from .ctc_decoder import CTCDecoder
    from .dan_decoder import DANDecoder
    from .igtr_decoder import IGTRDecoder
    from .lister_decoder import LISTERDecoder
    from .lpv_decoder import LPVDecoder
    from .mgp_decoder import MGPDecoder
    from .nrtr_decoder import NRTRDecoder
    from .parseq_decoder import PARSeqDecoder
    from .robustscanner_decoder import RobustScannerDecoder
    from .sar_decoder import SARDecoder
    from .smtr_decoder import SMTRDecoder
    from .smtr_decoder_nattn import SMTRDecoderNumAttn
    from .srn_decoder import SRNDecoder
    from .visionlan_decoder import VisionLANDecoder
    from .matrn_decoder import MATRNDecoder
    from .cam_decoder import CAMDecoder
    from .ote_decoder import OTEDecoder
    from .bus_decoder import BUSDecoder

    # Explicit registry instead of eval(): only these names can be built.
    # Insertion order matches the old list so the error message is unchanged.
    support_dict = {
        'CTCDecoder': CTCDecoder,
        'NRTRDecoder': NRTRDecoder,
        'CPPDDecoder': CPPDDecoder,
        'ABINetDecoder': ABINetDecoder,
        'CDistNetDecoder': CDistNetDecoder,
        'VisionLANDecoder': VisionLANDecoder,
        'PARSeqDecoder': PARSeqDecoder,
        'IGTRDecoder': IGTRDecoder,
        'SMTRDecoder': SMTRDecoder,
        'LPVDecoder': LPVDecoder,
        'SARDecoder': SARDecoder,
        'RobustScannerDecoder': RobustScannerDecoder,
        'SRNDecoder': SRNDecoder,
        'ASTERDecoder': ASTERDecoder,
        'RCTCDecoder': RCTCDecoder,
        'LISTERDecoder': LISTERDecoder,
        'GTCDecoder': GTCDecoder,  # defined later in this module
        'SMTRDecoderNumAttn': SMTRDecoderNumAttn,
        'MATRNDecoder': MATRNDecoder,
        'MGPDecoder': MGPDecoder,
        'DANDecoder': DANDecoder,
        'CAMDecoder': CAMDecoder,
        'OTEDecoder': OTEDecoder,
        'BUSDecoder': BUSDecoder,
    }

    # Work on a shallow copy: popping 'name' from the caller's dict would make
    # a second build from the same config fail with KeyError.
    config = dict(config)
    module_name = config.pop('name')
    # Explicit check (not assert) so it still runs under `python -O`.
    if module_name not in support_dict:
        raise _UnsupportedDecoderError(
            'decoder only support {}'.format(list(support_dict)))
    return support_dict[module_name](**config)
class GTCDecoder(nn.Module):
    """Pair a CTC decoder with an optional GTC decoder over one feature map.

    Sub-decoders are built with ``build_decoder`` from their own config dicts.
    ``in_channels``/``out_channels`` are written into those dicts in place, so
    the dicts passed in are modified.

    Args:
        in_channels: feature width written into the sub-decoder config(s).
        gtc_decoder: config dict for the GTC decoder; only used (and only
            built) when ``infer_gtc`` is true.
        ctc_decoder: config dict for the CTC decoder.
        detach: if true, the CTC decoder is fed ``x.detach()``, so no gradient
            flows back into ``x`` through the CTC branch.
        infer_gtc: if true, build the GTC decoder and return both predictions
            in eval mode as well as in training mode. If false, the GTC
            decoder is not built and the module can only run in eval mode.
        out_channels: when ``infer_gtc`` is true it is indexed as
            ``(gtc_out, ctc_out)``; otherwise it is passed unchanged to the
            CTC decoder.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 detach=True,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super(GTCDecoder, self).__init__()
        self.detach = detach
        self.infer_gtc = infer_gtc
        if infer_gtc:
            gtc_decoder['out_channels'] = out_channels[0]
            ctc_decoder['out_channels'] = out_channels[1]
            gtc_decoder['in_channels'] = in_channels
            ctc_decoder['in_channels'] = in_channels
            self.gtc_decoder = build_decoder(gtc_decoder)
        else:
            ctc_decoder['in_channels'] = in_channels
            ctc_decoder['out_channels'] = out_channels
            # No GTC branch is built; record that explicitly so forward() can
            # report it instead of failing with a bare AttributeError.
            self.gtc_decoder = None
        self.ctc_decoder = build_decoder(ctc_decoder)

    def forward(self, x, data=None):
        """Run the CTC branch and, when required, the GTC branch.

        Args:
            x: feature tensor; the GTC branch reshapes it with
                ``flatten(2).transpose(1, 2)``, i.e. (N, C, *) -> (N, L, C),
                so it needs rank >= 3.
            data: forwarded unchanged to the sub-decoders.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` in training mode or when
            ``infer_gtc`` is true; otherwise just the CTC prediction.

        Raises:
            RuntimeError: if the GTC output is required (training mode) but
                the module was built with ``infer_gtc=False``.
        """
        want_gtc = self.training or self.infer_gtc
        # Fail fast, before spending time on the CTC branch.
        if want_gtc and self.gtc_decoder is None:
            raise RuntimeError(
                'GTCDecoder was built with infer_gtc=False and has no GTC '
                'decoder; it can only be run in eval mode (call .eval()).')
        ctc_pred = self.ctc_decoder(x.detach() if self.detach else x,
                                    data=data)
        if not want_gtc:
            return ctc_pred
        gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2), data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
class GTCDecoderTwo(nn.Module):
    """CTC + GTC decoder pair where each branch reads its own feature map.

    Both sub-decoders are always built via ``build_decoder``; their config
    dicts receive ``out_channels`` and ``in_channels`` in place.

    Args:
        in_channels: feature width written into both sub-decoder configs.
        gtc_decoder: config dict for the GTC decoder (modified in place).
        ctc_decoder: config dict for the CTC decoder (modified in place).
        infer_gtc: if true, the GTC prediction is also returned in eval mode.
        out_channels: indexed as ``(gtc_out, ctc_out)``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super(GTCDecoderTwo, self).__init__()
        self.infer_gtc = infer_gtc
        sub_configs = (gtc_decoder, ctc_decoder)
        # Output widths first (gtc -> index 0, ctc -> index 1), then the
        # shared input width; same write order as before.
        for idx, sub_cfg in enumerate(sub_configs):
            sub_cfg['out_channels'] = out_channels[idx]
        for sub_cfg in sub_configs:
            sub_cfg['in_channels'] = in_channels
        # Build order matters (parameter init / registration): GTC first.
        self.gtc_decoder = build_decoder(gtc_decoder)
        self.ctc_decoder = build_decoder(ctc_decoder)

    def forward(self, x, data=None):
        """Decode a ``(ctc_features, gtc_features)`` pair.

        Args:
            x: two-item sequence; the second item is reshaped with
                ``flatten(2).transpose(1, 2)``, i.e. (N, C, *) -> (N, L, C),
                before it reaches the GTC decoder.
            data: forwarded unchanged to the sub-decoders.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` in training mode or when
            ``infer_gtc`` is true; otherwise only the CTC prediction.
        """
        ctc_feats, gtc_feats = x
        ctc_out = self.ctc_decoder(ctc_feats, data=data)
        if not (self.training or self.infer_gtc):
            return ctc_out
        gtc_tokens = gtc_feats.flatten(2).transpose(1, 2)
        return {
            'gtc_pred': self.gtc_decoder(gtc_tokens, data=data),
            'ctc_pred': ctc_out,
        }
|