Spaces:
Running
Running
import copy | |
from torch import nn | |
from .abinet_loss import ABINetLoss | |
from .ar_loss import ARLoss | |
from .cdistnet_loss import CDistNetLoss | |
from .ce_loss import CELoss | |
from .cppd_loss import CPPDLoss | |
from .ctc_loss import CTCLoss | |
from .igtr_loss import IGTRLoss | |
from .lister_loss import LISTERLoss | |
from .lpv_loss import LPVLoss | |
from .mgp_loss import MGPLoss | |
from .parseq_loss import PARSeqLoss | |
from .robustscanner_loss import RobustScannerLoss | |
from .smtr_loss import SMTRLoss | |
from .srn_loss import SRNLoss | |
from .visionlan_loss import VisionLANLoss | |
from .cam_loss import CAMLoss | |
from .seed_loss import SEEDLoss | |
# Loss names accepted by ``build_loss``. Each entry must name a class that is
# reachable from this module's namespace: one of the imports above, or a
# class defined in this file (``GTCLoss``).
support_dict = [
    'CTCLoss',
    'ARLoss',
    'CELoss',
    'CPPDLoss',
    'ABINetLoss',
    'CDistNetLoss',
    'VisionLANLoss',
    'PARSeqLoss',
    'IGTRLoss',
    'SMTRLoss',
    'LPVLoss',
    'RobustScannerLoss',
    'SRNLoss',
    'LISTERLoss',
    'GTCLoss',
    'MGPLoss',
    'CAMLoss',
    'SEEDLoss',
]
def build_loss(config):
    """Instantiate a loss module from a config mapping.

    Args:
        config: Mapping with a ``'name'`` key naming one of the entries in
            ``support_dict``. Every other key is passed to that class's
            constructor as a keyword argument. The mapping is deep-copied
            first, so the caller's object is never modified.

    Returns:
        An instance of the named loss class.

    Raises:
        KeyError: If ``config`` has no ``'name'`` key.
        AssertionError: If ``config['name']`` is not listed in
            ``support_dict``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    # Explicit raise instead of a bare ``assert``: an ``assert`` is stripped
    # under ``python -O``, which would let an arbitrary config string through.
    # AssertionError is kept so existing callers see the same exception type.
    if module_name not in support_dict:
        raise AssertionError('loss only support {}'.format(support_dict))
    # Look the class up in this module's namespace instead of ``eval``-ing a
    # config-supplied string; every name in ``support_dict`` is imported or
    # defined in this module.
    loss_class = globals()[module_name]
    return loss_class(**config)
class GTCLoss(nn.Module):
    """Weighted sum of a CTC loss and a second, config-built "GTC" loss.

    ``forward`` reads one prediction per branch from ``predicts`` and splits
    ``batch`` between the branches. The last two entries feed the CTC loss;
    everything before them feeds the GTC loss.

    Args:
        gtc_loss: Config mapping for the GTC-branch loss. It is forwarded to
            ``build_loss``, so it needs a ``'name'`` key listed in
            ``support_dict``.
        gtc_weight: Multiplier applied to the GTC-branch loss in the total.
        ctc_weight: Multiplier applied to the CTC loss in the total.
        zero_infinity: Forwarded unchanged to ``CTCLoss``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 gtc_loss,
                 gtc_weight=1.0,
                 ctc_weight=1.0,
                 zero_infinity=True,
                 **kwargs):
        super().__init__()
        self.ctc_loss = CTCLoss(zero_infinity=zero_infinity)
        self.gtc_loss = build_loss(gtc_loss)
        self.gtc_weight = gtc_weight
        self.ctc_weight = ctc_weight

    def forward(self, predicts, batch):
        """Compute the weighted CTC + GTC loss.

        Args:
            predicts: Mapping with ``'ctc_pred'`` and ``'gtc_pred'`` entries.
            batch: Indexable sequence (list or tuple). Its last two entries
                go to the CTC loss (presumably CTC targets and their
                lengths; TODO confirm against the data pipeline). All
                preceding entries go to the GTC loss.

        Returns:
            Dict with ``'loss'`` (``ctc_weight * ctc_loss +
            gtc_weight * gtc_loss``) plus the unweighted ``'ctc_loss'`` and
            ``'gtc_loss'`` terms.
        """
        # Slot 0 is padded with None, presumably because CTCLoss reads its
        # targets starting at index 1 (confirm in ctc_loss.py). Unpacking,
        # unlike ``[None] + batch[-2:]``, also works when ``batch`` is a tuple.
        ctc_batch = [None, *batch[-2:]]
        ctc_loss = self.ctc_loss(predicts['ctc_pred'], ctc_batch)['loss']
        gtc_loss = self.gtc_loss(predicts['gtc_pred'], batch[:-2])['loss']
        return {
            'loss': self.ctc_weight * ctc_loss + self.gtc_weight * gtc_loss,
            'ctc_loss': ctc_loss,
            'gtc_loss': gtc_loss
        }