File size: 1,511 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch import nn


class ABINetLoss(nn.Module):
    """Sum of cross-entropy losses over the entries of a prediction dict.

    Every entry of ``pred`` is scored against ``batch[1]`` with an
    ``nn.CrossEntropyLoss(reduction='mean')`` criterion. The loss of the
    entry named ``'align'`` is additionally multiplied by ``align_weight``.

    Args:
        smoothing: Stored as ``self.smoothing``. It is not read anywhere in
            this class, so it currently has no effect on the loss.
        ignore_index: When non-negative, forwarded to ``nn.CrossEntropyLoss``
            as its ``ignore_index``. When negative, the criterion is built
            without that argument and PyTorch's own default applies.
        align_weight: Multiplier applied to the loss of the ``'align'`` entry.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self,
                 smoothing=False,
                 ignore_index=100,
                 align_weight=1.0,
                 **kwargs):
        super().__init__()
        if ignore_index >= 0:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean',
                                                 ignore_index=ignore_index)
        else:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean')
        self.smoothing = smoothing
        self.align_weight = align_weight

    def forward(self, pred, batch):
        """Compute one loss per entry of ``pred`` plus their sum.

        Args:
            pred: Mapping from a name to either a logits tensor or a
                list/tuple of logits tensors. Tensors are read at
                ``shape[2]`` for the class count, so they need at least
                three dimensions -- presumably (batch, length, classes);
                confirm against the model that produces them.
            batch: Indexable whose element ``1`` is the target tensor,
                presumably class indices shaped (batch, length) -- confirm
                against the data loader.

        Returns:
            dict: ``'<name>_loss'`` for every entry that contributed, and
            ``'loss'`` holding their sum. Empty sequences are skipped. If
            nothing contributes, ``'loss'`` is the plain integer ``0``.
        """
        loss = {}
        loss_sum = []
        for name, logits in pred.items():
            if isinstance(logits, (list, tuple)):
                logit_num = len(logits)
                if logit_num == 0:
                    # Nothing to score for this entry.
                    continue
                # Stack all outputs along dim 0 and repeat the target the
                # same number of times so rows line up after flattening.
                all_logits = torch.cat(logits, 0)
                all_tgt = torch.cat([batch[1]] * logit_num, 0)
                flat_logits = all_logits.reshape([-1, all_logits.shape[2]])
                flat_tgt = all_tgt.reshape([-1])
            else:
                flat_logits = logits.reshape([-1, logits.shape[2]])
                flat_tgt = batch[1].reshape([-1])

            # Only the entry named 'align' is re-weighted.
            weight = self.align_weight if name == 'align' else 1.0
            entry_loss = self.loss_func(flat_logits, flat_tgt) * weight
            loss[name + '_loss'] = entry_loss
            loss_sum.append(entry_loss)
        loss['loss'] = sum(loss_sum)
        return loss