Spaces:
Running
Running
import torch | |
from torch import nn | |
class ABINetLoss(nn.Module):
    """Cross-entropy loss summed over every named prediction head.

    ``forward`` receives a dict of named predictions. Each value is either a
    single logits tensor or a list/tuple of logits tensors. Every entry is
    scored against the same target ``batch[1]``. For a list/tuple entry, the
    tensors are concatenated along dim 0 and the target is repeated once per
    tensor to match.

    Args:
        smoothing: Stored as ``self.smoothing`` for config compatibility but
            NOT applied anywhere in this class (no label smoothing is used).
        ignore_index: Target value excluded from the loss. A negative value
            means "do not pass one", so ``nn.CrossEntropyLoss``'s own default
            ``ignore_index`` applies instead.
        align_weight: Multiplier applied to the loss of the entry whose key is
            exactly ``'align'``; every other entry is weighted 1.0.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 smoothing=False,
                 ignore_index=100,
                 align_weight=1.0,
                 **kwargs):
        super().__init__()
        if ignore_index >= 0:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean',
                                                 ignore_index=ignore_index)
        else:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean')
        # NOTE(review): not read anywhere in this class; kept only so configs
        # that pass it keep working.
        self.smoothing = smoothing
        self.align_weight = align_weight

    @staticmethod
    def _flatten(logits, target):
        """Return ``(flat_logits, flat_target)`` for one entry, or None for an empty list/tuple."""
        if isinstance(logits, (list, tuple)):
            if not logits:
                return None
            # One copy of the target per stacked logits tensor, in the same
            # order as the concatenation below.
            target = torch.cat([target] * len(logits), 0)
            logits = torch.cat(logits, 0)
        # Fold all leading dims together so the criterion sees (N, C) logits
        # and (N,) targets; the last dim is treated as the class dim.
        return logits.reshape([-1, logits.shape[-1]]), target.reshape([-1])

    def forward(self, pred, batch):
        """Compute per-head losses and their sum.

        Args:
            pred: Mapping from head name to a logits tensor or a list/tuple
                of logits tensors. Empty lists/tuples are skipped.
            batch: Indexable; ``batch[1]`` is used as the target for every
                head (assumed integer class indices whose leading dims match
                the logits — TODO confirm against the data pipeline).

        Returns:
            Dict with ``'<name>_loss'`` for each non-skipped head plus
            ``'loss'``, the sum of those values (the int ``0`` if no head
            contributed).
        """
        loss = {}
        loss_sum = []
        for name, logits in pred.items():
            flat = self._flatten(logits, batch[1])
            if flat is None:
                continue
            flat_logits, flat_target = flat
            weight = self.align_weight if name == 'align' else 1.0
            key = name + '_loss'
            loss[key] = self.loss_func(flat_logits, flat_target) * weight
            loss_sum.append(loss[key])
        loss['loss'] = sum(loss_sum)
        return loss