"""Registry and factory for the text-recognition loss functions."""
import copy

from torch import nn

from .abinet_loss import ABINetLoss
from .ar_loss import ARLoss
from .cam_loss import CAMLoss
from .cdistnet_loss import CDistNetLoss
from .ce_loss import CELoss
from .cppd_loss import CPPDLoss
from .ctc_loss import CTCLoss
from .igtr_loss import IGTRLoss
from .lister_loss import LISTERLoss
from .lpv_loss import LPVLoss
from .mgp_loss import MGPLoss
from .parseq_loss import PARSeqLoss
from .robustscanner_loss import RobustScannerLoss
from .seed_loss import SEEDLoss
from .smtr_loss import SMTRLoss
from .srn_loss import SRNLoss
from .visionlan_loss import VisionLANLoss

# Loss class names accepted in the ``name`` field of a loss config. Every
# entry must be a class bound at module level in this file (imported above
# or defined below), because ``build_loss`` resolves it via ``globals()``.
support_dict = [
    'CTCLoss', 'ARLoss', 'CELoss', 'CPPDLoss', 'ABINetLoss', 'CDistNetLoss',
    'VisionLANLoss', 'PARSeqLoss', 'IGTRLoss', 'SMTRLoss', 'LPVLoss',
    'RobustScannerLoss', 'SRNLoss', 'LISTERLoss', 'GTCLoss', 'MGPLoss',
    'CAMLoss', 'SEEDLoss'
]


class _UnsupportedLossError(ValueError, AssertionError):
    """Raised by ``build_loss`` when the requested name is not supported.

    It subclasses ``AssertionError`` so callers written against the earlier
    ``assert``-based check still catch it. It subclasses ``ValueError``
    because that is the idiomatic type for a bad config value.
    """


def build_loss(config):
    """Instantiate a loss module from a config mapping.

    Args:
        config: mapping whose ``name`` key is one of ``support_dict``; all
            remaining keys are passed as keyword arguments to that class.
            The mapping is deep-copied first, so the caller's object is left
            unmodified.

    Returns:
        An instance of the selected loss class.

    Raises:
        KeyError: if ``config`` has no ``name`` key.
        _UnsupportedLossError: (a ``ValueError`` and ``AssertionError``
            subclass) if ``name`` is not listed in ``support_dict``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    # Explicit check rather than ``assert``: an assert disappears under
    # ``python -O``, and this whitelist is what keeps arbitrary config
    # strings away from the name lookup below.
    if module_name not in support_dict:
        raise _UnsupportedLossError('loss only support {}, got {!r}'.format(
            support_dict, module_name))
    # Resolve the class by name without ``eval``. ``GTCLoss`` is defined
    # later in this module, but it is bound by the time this function runs.
    loss_cls = globals()[module_name]
    return loss_cls(**config)


class GTCLoss(nn.Module):
    """Weighted sum of a CTC loss and a second, config-built loss.

    Args:
        gtc_loss: config mapping for the second loss, built via
            ``build_loss`` (so it needs a ``name`` key from ``support_dict``).
        gtc_weight: multiplier applied to the second loss term.
        ctc_weight: multiplier applied to the CTC loss term.
        zero_infinity: forwarded to ``CTCLoss``.
        **kwargs: accepted and ignored (presumably so extra shared config
            keys do not break construction; confirm against callers).
    """

    def __init__(self,
                 gtc_loss,
                 gtc_weight=1.0,
                 ctc_weight=1.0,
                 zero_infinity=True,
                 **kwargs):
        super().__init__()
        self.ctc_loss = CTCLoss(zero_infinity=zero_infinity)
        self.gtc_loss = build_loss(gtc_loss)
        self.gtc_weight = gtc_weight
        self.ctc_weight = ctc_weight

    def forward(self, predicts, batch):
        """Compute the combined loss.

        Args:
            predicts: mapping with ``'ctc_pred'`` and ``'gtc_pred'`` entries.
            batch: sequence (list or tuple). Its last two entries go to the
                CTC loss, preceded by ``None``. The ``None`` presumably fills
                a slot the CTC loss does not use; confirm against
                ``CTCLoss``. All earlier entries go to the second loss.

        Returns:
            dict with ``'loss'`` (the weighted sum), plus the unweighted
            ``'ctc_loss'`` and ``'gtc_loss'`` terms.
        """
        # Unpacking instead of ``[None] + batch[-2:]`` also accepts tuples.
        ctc_batch = [None, *batch[-2:]]
        ctc_loss = self.ctc_loss(predicts['ctc_pred'], ctc_batch)['loss']
        gtc_loss = self.gtc_loss(predicts['gtc_pred'], batch[:-2])['loss']
        return {
            'loss': self.ctc_weight * ctc_loss + self.gtc_weight * gtc_loss,
            'ctc_loss': ctc_loss,
            'gtc_loss': gtc_loss,
        }