import torch
from torch import nn


class ABINetLoss(nn.Module):
    """Multi-branch cross-entropy loss for ABINet-style text recognition.

    Each entry in the prediction dict is either a single logits tensor or a
    list/tuple of logits tensors (e.g. iterative refinement steps).  A mean
    cross-entropy is computed per branch against the targets in ``batch[1]``,
    the ``'align'`` branch is scaled by ``align_weight``, and the total is
    returned under the key ``'loss'``.

    Args:
        smoothing: stored for config compatibility; not used in this class
            (NOTE(review): presumably consumed by a subclass or caller — confirm).
        ignore_index: target value excluded from the loss; a negative value
            disables ignoring entirely.
        align_weight: multiplier applied to the 'align' branch loss only.
    """

    def __init__(self, smoothing=False, ignore_index=100, align_weight=1.0, **kwargs):
        super().__init__()
        # A negative ignore_index means "mask nothing" — build a plain CE loss.
        if ignore_index >= 0:
            self.loss_func = nn.CrossEntropyLoss(
                reduction='mean', ignore_index=ignore_index)
        else:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean')
        self.smoothing = smoothing
        self.align_weight = align_weight

    def forward(self, pred, batch):
        """Compute per-branch and total losses.

        Args:
            pred: dict mapping branch name -> logits tensor of shape
                (B, T, C), or a list/tuple of such tensors.
            batch: sequence whose index 1 holds the integer targets of
                shape (B, T).

        Returns:
            dict with one ``'<name>_loss'`` scalar per non-empty branch and
            their sum under ``'loss'``.
        """
        loss = {}
        loss_sum = []
        for name, logits in pred.items():
            # Accept tuples as well as lists of per-step logits.
            if isinstance(logits, (list, tuple)):
                if not logits:
                    # Empty branch contributes nothing.
                    continue
                # Tile the targets once per step so a single CE call
                # averages over all steps at once.
                all_tgt = torch.cat([batch[1]] * len(logits), 0)
                all_logits = torch.cat(list(logits), 0)
                flat_logits = all_logits.reshape([-1, all_logits.shape[2]])
                flat_tgt = all_tgt.reshape([-1])
            else:
                flat_logits = logits.reshape([-1, logits.shape[2]])
                flat_tgt = batch[1].reshape([-1])
            # Only the alignment branch is re-weighted.
            weight = self.align_weight if name == 'align' else 1.0
            loss[name + '_loss'] = self.loss_func(flat_logits, flat_tgt) * weight
            loss_sum.append(loss[name + '_loss'])
        loss['loss'] = sum(loss_sum)
        return loss