Spaces:
Running
Running
File size: 576 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from torch import nn
class RobustScannerLoss(nn.Module):
def __init__(self, **kwargs):
super(RobustScannerLoss, self).__init__()
ignore_index = kwargs.get('ignore_index', 38)
self.loss_func = nn.CrossEntropyLoss(reduction='mean',
ignore_index=ignore_index)
def forward(self, pred, batch):
pred = pred[:, :-1, :]
label = batch[1][:, 1:].reshape([-1])
inputs = pred.reshape([-1, pred.shape[2]])
loss = self.loss_func(inputs, label)
return {'loss': loss}
|