Spaces:
Build error
Build error
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class MLECriterion(nn.Module):
    """
    Loss criterion computed from the model output and the ground-truth
    token ids carried in the batch.

    The loss is ``F.nll_loss`` applied to the flattened model output and the
    flattened ``batch["decoder_input_ids"]``; positions whose target equals
    ``self.ignore_index`` are skipped via ``ignore_index``.
    """

    def __init__(self, opt, module):
        """
        Args:
            opt: Option container supporting ``in`` and ``[]``. Keys read
                here and in ``forward``: ``"IGNORE_INDEX"``,
                ``"USE_BOS_TOKEN"``, ``"USE_EOS_TOKEN"``.
            module: Only used when ``"IGNORE_INDEX"`` is absent from ``opt``,
                in which case ``module.tokenizer.pad_token_id`` becomes the
                ignored target id.
        """
        super().__init__()
        self.opt = opt
        # An explicit option wins; the tokenizer is consulted only as a
        # fallback, so `module` is never touched when the option is present.
        if "IGNORE_INDEX" in self.opt:
            self.ignore_index = self.opt["IGNORE_INDEX"]
        else:
            self.ignore_index = module.tokenizer.pad_token_id

    def forward(self, vocab_logprob, batch):
        """
        Return the NLL loss of ``vocab_logprob`` against the batch targets.

        Args:
            vocab_logprob: Tensor whose third dimension (``shape[2]``) is the
                class dimension. Presumably log-probabilities laid out as
                (batch, sequence, vocabulary) -- TODO confirm with the caller.
            batch: Mapping providing the target ids under
                ``"decoder_input_ids"``.

        Returns:
            The tensor produced by ``F.nll_loss`` with its default reduction.

        Note:
            ``"USE_BOS_TOKEN"`` and ``"USE_EOS_TOKEN"`` are tested for key
            presence only; their values are never read.
        """
        # Read the class count up front; the slicing below leaves dim 2 alone.
        num_classes = vocab_logprob.shape[2]
        targets = batch["decoder_input_ids"]

        # Drop the first target position when the BOS option is present.
        if "USE_BOS_TOKEN" in self.opt:
            targets = targets[:, 1:]

        # Drop the last prediction position when the EOS option is present.
        if "USE_EOS_TOKEN" in self.opt:
            vocab_logprob = vocab_logprob[:, :-1, :]

        flat_logprob = vocab_logprob.reshape(-1, num_classes)
        flat_targets = targets.reshape(-1)
        return F.nll_loss(
            flat_logprob,
            flat_targets,
            ignore_index=self.ignore_index,
        )