File size: 1,981 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import copy

from torch import nn

from .abinet_loss import ABINetLoss
from .ar_loss import ARLoss
from .cdistnet_loss import CDistNetLoss
from .ce_loss import CELoss
from .cppd_loss import CPPDLoss
from .ctc_loss import CTCLoss
from .igtr_loss import IGTRLoss
from .lister_loss import LISTERLoss
from .lpv_loss import LPVLoss
from .mgp_loss import MGPLoss
from .parseq_loss import PARSeqLoss
from .robustscanner_loss import RobustScannerLoss
from .smtr_loss import SMTRLoss
from .srn_loss import SRNLoss
from .visionlan_loss import VisionLANLoss
from .cam_loss import CAMLoss
from .seed_loss import SEEDLoss

# Names that build_loss() accepts as the 'name' entry of a loss config.
# Kept as a list because it is formatted verbatim into the error message
# raised for an unsupported name.
support_dict = [
    'CTCLoss',
    'ARLoss',
    'CELoss',
    'CPPDLoss',
    'ABINetLoss',
    'CDistNetLoss',
    'VisionLANLoss',
    'PARSeqLoss',
    'IGTRLoss',
    'SMTRLoss',
    'LPVLoss',
    'RobustScannerLoss',
    'SRNLoss',
    'LISTERLoss',
    'GTCLoss',
    'MGPLoss',
    'CAMLoss',
    'SEEDLoss',
]


def build_loss(config):
    """Instantiate the loss object described by ``config``.

    Args:
        config (dict): Loss configuration. Its ``'name'`` entry selects the
            loss class and must be one of the names in ``support_dict``;
            every remaining entry is passed to that class as a keyword
            argument. A deep copy is consumed, so the caller's dict is not
            modified.

    Returns:
        The object produced by calling the selected class with the
        remaining config entries.

    Raises:
        KeyError: If ``config`` has no ``'name'`` entry.
        AssertionError: If the requested name is not in ``support_dict``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped under ``python -O``. AssertionError (with the same message
    # text) is kept so existing callers see the same exception type.
    if module_name not in support_dict:
        raise AssertionError('loss only support {}'.format(support_dict))
    # Look the class up in this module's namespace rather than eval()'ing a
    # string that comes from configuration.
    module_class = globals()[module_name]
    return module_class(**config)


class GTCLoss(nn.Module):
    """Weighted sum of a CTC loss and a second, configurable loss.

    The CTC term is always a ``CTCLoss``; the other term (``gtc_loss``) is
    created by ``build_loss`` from a config dict.
    """

    def __init__(self,
                 gtc_loss,
                 gtc_weight=1.0,
                 ctc_weight=1.0,
                 zero_infinity=True,
                 **kwargs):
        """Create both sub-losses and store their weights.

        Args:
            gtc_loss (dict): Config passed to ``build_loss`` to create the
                non-CTC loss.
            gtc_weight (float): Multiplier applied to the non-CTC loss.
            ctc_weight (float): Multiplier applied to the CTC loss.
            zero_infinity (bool): Forwarded to ``CTCLoss``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.ctc_weight = ctc_weight
        self.gtc_weight = gtc_weight
        self.ctc_loss = CTCLoss(zero_infinity=zero_infinity)
        self.gtc_loss = build_loss(gtc_loss)

    def forward(self, predicts, batch):
        """Compute both losses and their weighted sum.

        Args:
            predicts (dict): Must contain ``'ctc_pred'`` and ``'gtc_pred'``.
            batch (list): The last two items are handed to the CTC loss
                (prefixed with a ``None`` placeholder); everything before
                them is handed to the other loss.
                NOTE(review): presumably the last two items are the CTC
                label and its length -- confirm against the data loader.

        Returns:
            dict: ``'loss'`` (weighted total), ``'ctc_loss'`` and
            ``'gtc_loss'`` (the unweighted terms).
        """
        ctc_batch = [None] + batch[-2:]
        ctc_term = self.ctc_loss(predicts['ctc_pred'], ctc_batch)['loss']

        gtc_batch = batch[:-2]
        gtc_term = self.gtc_loss(predicts['gtc_pred'], gtc_batch)['loss']

        total = self.ctc_weight * ctc_term + self.gtc_weight * gtc_term
        return {'loss': total, 'ctc_loss': ctc_term, 'gtc_loss': gtc_term}