File size: 1,503 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch.nn.functional as F
from torch import nn


class SRNLoss(nn.Module):

    def __init__(self, label_smoothing=0.0, **kwargs):
        super(SRNLoss, self).__init__()
        self.label_smoothing = label_smoothing

    def forward(self, preds, batch):
        pvam_preds, gsrm_preds, vsfd_preds = preds

        label = batch[1].reshape([-1])

        ignore_index = pvam_preds.shape[-1] + 1

        loss_pvam = F.cross_entropy(pvam_preds,
                                    label,
                                    reduction='mean',
                                    label_smoothing=self.label_smoothing,
                                    ignore_index=ignore_index)
        loss_gsrm = F.cross_entropy(gsrm_preds,
                                    label,
                                    reduction='mean',
                                    label_smoothing=self.label_smoothing,
                                    ignore_index=ignore_index)
        loss_vsfd = F.cross_entropy(vsfd_preds,
                                    label,
                                    reduction='mean',
                                    label_smoothing=self.label_smoothing,
                                    ignore_index=ignore_index)

        loss_dict = {}
        loss_dict['loss_pvam'] = loss_pvam
        loss_dict['loss_gsrm'] = loss_gsrm
        loss_dict['loss_vsfd'] = loss_vsfd

        loss_dict['loss'] = loss_pvam * 3.0 + loss_gsrm * 0.15 + loss_vsfd
        return loss_dict