Spaces:
Running
Running
import torch.nn as nn | |
__all__ = ['build_decoder'] | |
def build_decoder(config): | |
# rec decoder | |
from .abinet_decoder import ABINetDecoder | |
from .aster_decoder import ASTERDecoder | |
from .cdistnet_decoder import CDistNetDecoder | |
from .cppd_decoder import CPPDDecoder | |
from .rctc_decoder import RCTCDecoder | |
from .ctc_decoder import CTCDecoder | |
from .dan_decoder import DANDecoder | |
from .igtr_decoder import IGTRDecoder | |
from .lister_decoder import LISTERDecoder | |
from .lpv_decoder import LPVDecoder | |
from .mgp_decoder import MGPDecoder | |
from .nrtr_decoder import NRTRDecoder | |
from .parseq_decoder import PARSeqDecoder | |
from .robustscanner_decoder import RobustScannerDecoder | |
from .sar_decoder import SARDecoder | |
from .smtr_decoder import SMTRDecoder | |
from .smtr_decoder_nattn import SMTRDecoderNumAttn | |
from .srn_decoder import SRNDecoder | |
from .visionlan_decoder import VisionLANDecoder | |
from .matrn_decoder import MATRNDecoder | |
from .cam_decoder import CAMDecoder | |
from .ote_decoder import OTEDecoder | |
from .bus_decoder import BUSDecoder | |
support_dict = [ | |
'CTCDecoder', 'NRTRDecoder', 'CPPDDecoder', 'ABINetDecoder', | |
'CDistNetDecoder', 'VisionLANDecoder', 'PARSeqDecoder', 'IGTRDecoder', | |
'SMTRDecoder', 'LPVDecoder', 'SARDecoder', 'RobustScannerDecoder', | |
'SRNDecoder', 'ASTERDecoder', 'RCTCDecoder', 'LISTERDecoder', | |
'GTCDecoder', 'SMTRDecoderNumAttn', 'MATRNDecoder', 'MGPDecoder', | |
'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder' | |
] | |
module_name = config.pop('name') | |
assert module_name in support_dict, Exception( | |
'decoder only support {}'.format(support_dict)) | |
module_class = eval(module_name)(**config) | |
return module_class | |
class GTCDecoder(nn.Module): | |
def __init__(self, | |
in_channels, | |
gtc_decoder, | |
ctc_decoder, | |
detach=True, | |
infer_gtc=False, | |
out_channels=0, | |
**kwargs): | |
super(GTCDecoder, self).__init__() | |
self.detach = detach | |
self.infer_gtc = infer_gtc | |
if infer_gtc: | |
gtc_decoder['out_channels'] = out_channels[0] | |
ctc_decoder['out_channels'] = out_channels[1] | |
gtc_decoder['in_channels'] = in_channels | |
ctc_decoder['in_channels'] = in_channels | |
self.gtc_decoder = build_decoder(gtc_decoder) | |
else: | |
ctc_decoder['in_channels'] = in_channels | |
ctc_decoder['out_channels'] = out_channels | |
self.ctc_decoder = build_decoder(ctc_decoder) | |
def forward(self, x, data=None): | |
ctc_pred = self.ctc_decoder(x.detach() if self.detach else x, | |
data=data) | |
if self.training or self.infer_gtc: | |
gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2), | |
data=data) | |
return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred} | |
else: | |
return ctc_pred | |
class GTCDecoderTwo(nn.Module): | |
def __init__(self, | |
in_channels, | |
gtc_decoder, | |
ctc_decoder, | |
infer_gtc=False, | |
out_channels=0, | |
**kwargs): | |
super(GTCDecoderTwo, self).__init__() | |
self.infer_gtc = infer_gtc | |
gtc_decoder['out_channels'] = out_channels[0] | |
ctc_decoder['out_channels'] = out_channels[1] | |
gtc_decoder['in_channels'] = in_channels | |
ctc_decoder['in_channels'] = in_channels | |
self.gtc_decoder = build_decoder(gtc_decoder) | |
self.ctc_decoder = build_decoder(ctc_decoder) | |
def forward(self, x, data=None): | |
x_ctc, x_gtc = x | |
ctc_pred = self.ctc_decoder(x_ctc, data=data) | |
if self.training or self.infer_gtc: | |
gtc_pred = self.gtc_decoder(x_gtc.flatten(2).transpose(1, 2), | |
data=data) | |
return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred} | |
else: | |
return ctc_pred | |