import math

import torch
import torch.nn.functional as F

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
    """
    Compute the cross entropy loss. The default value of ignore_index is the
    same as the default value for F.cross_entropy in PyTorch.
    """
    assert logits.size(0) == targets.size(
        -1
    ), "Logits and Targets tensor shapes don't match up"

    loss = F.nll_loss(
        F.log_softmax(logits, -1, dtype=torch.float32),
        targets,
        reduction="sum",
        ignore_index=ignore_index,
    )
    return loss
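

# Example (illustrative): as used in this module, `logits` is a 2-D tensor of
# shape (num_predictions, vocab_size) and `targets` is a 1-D tensor of shape
# (num_predictions,); entries equal to `ignore_index` contribute nothing to
# the summed loss.
#
#   logits = torch.randn(8, 32)             # 8 predictions over a 32-token vocab
#   targets = torch.randint(0, 32, (8,))
#   loss = compute_cross_entropy_loss(logits, targets)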


@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
    """
    Implementation of the loss used in masked language model (MLM) training.
    This optionally also computes the next sentence prediction (NSP) loss and
    adds it to the overall loss based on the specified args. There are three
    cases to consider:
        1) Generic MLM training without NSP loss. In this case sentence_targets
           and sentence_logits are both None.
        2) BERT training without NSP loss. In this case sentence_targets is
           not None but sentence_logits is None and we should not be computing
           a sentence level loss.
        3) BERT training with NSP loss. In this case both sentence_targets and
           sentence_logits are not None and we should be computing a sentence
           level loss. The weight of the sentence level loss is specified as
           an argument.
    """

    def __init__(self, task, masked_lm_only, nsp_loss_weight):
        super().__init__(task)
        self.masked_lm_only = masked_lm_only
        self.nsp_loss_weight = nsp_loss_weight

    @staticmethod
    def add_args(parser):
        """Args for MaskedLM Loss"""
        parser.add_argument(
            "--masked-lm-only",
            default=False,
            action="store_true",
            help="compute MLM loss only",
        )
        parser.add_argument(
            "--nsp-loss-weight",
            default=1.0,
            type=float,
            help="weight for next sentence prediction loss (default 1)",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])
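
        # Flatten the MLM logits to (num_tokens, vocab_size) and the targets to
        # (num_tokens,) so the token-level cross-entropy can be computed in a
        # single call; padding positions are excluded via self.padding_idx.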
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample["lm_target"].view(-1)
        lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
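
        # Normalize the summed MLM loss by the number of non-pad target tokens
        # so the loss is a per-token average.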
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample["nsentences"]
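
        # Optionally add the sentence-level (NSP) loss, following the three
        # cases described in the class docstring.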
        sentence_loss = None
        if not self.masked_lm_only:
            sentence_logits = output_metadata["sentence_logits"]
            sentence_targets = sample["sentence_target"].view(-1)
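
            # Recompute nsentences from the sentence-level targets so the NSP
            # loss below is normalized by the number of sentence targets
            # actually present in this sample.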
            nsentences = sentence_targets.size(0)
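
            # sentence_logits can be None (case 2 in the class docstring), in
            # which case no sentence-level loss is computed.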
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets
                )

                loss += self.nsp_loss_weight * (sentence_loss / nsentences)
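
        # The loss terms are already normalized (per token for MLM, per
        # sentence for NSP), so a constant sample_size of 1 is returned as the
        # gradient denominator; it is reported mainly for logging.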
        sample_size = 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
            "sentence_loss": (
                (utils.item(sentence_loss.data) if reduce else sentence_loss.data)
                if sentence_loss is not None
                else 0.0
            ),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
        sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
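
        # Losses are reported in bits: the accumulated natural-log values are
        # divided by math.log(2) before logging.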
        metrics.log_scalar(
            "loss",
            agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
            sample_size,
            round=3,
        )
        metrics.log_scalar(
            "lm_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )
        metrics.log_scalar(
            "sentence_loss",
            sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
            nsentences,
            round=3,
        )
        metrics.log_scalar(
            "nll_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
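

# Example invocation (illustrative; the exact task name and data arguments
# depend on the fairseq version and setup). The criterion is selected by the
# name it is registered under:
#
#   fairseq-train <data-dir> --task <masked-lm-task> \
#       --criterion legacy_masked_lm_loss --nsp-loss-weight 1.0
#
# Pass --masked-lm-only to compute only the MLM loss.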