Spaces:
Runtime error
Runtime error
import math | |
from fairseq.criterions import register_criterion | |
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion | |
from fairseq import metrics, utils | |
from collections import deque | |
import torch | |
import torch.nn as nn | |
class LabelSmoothedCrossEntropyCriterionWithContrastive(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross entropy plus a sentence-level contrastive term.

    The contrastive term encodes the source sentence and, via ``swap_sample``,
    the target sentence with the same ``model.encoder``. It mean-pools each
    encoding over its non-padding positions and applies an InfoNCE-style loss.
    For every batch index ``j``, the source embedding ``j`` must pick out target
    embedding ``j`` among all target embeddings of the batch. Scores are cosine
    similarities divided by ``temperature``.

    NOTE(review): ``register_criterion`` is imported by this module but is not
    applied to this class here, so the class is not selectable through
    ``--criterion`` unless it is registered elsewhere -- confirm.
    """

    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
        contrastive_lambda=0.0,
        temperature=1.0,
    ):
        """
        Args:
            task, sentence_avg, label_smoothing, ignore_prefix_size,
            report_accuracy: forwarded unchanged to the parent criterion.
            contrastive_lambda: weight of the contrastive term. ``0.0``
                disables it and skips the extra encoder passes.
            temperature: divisor applied to the cosine similarities before
                the log-softmax of the contrastive term.
        """
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        self.contrastive_lambda = contrastive_lambda
        self.temperature = temperature

    @staticmethod
    def add_args(parser):
        """Register the parent criterion's CLI arguments plus the contrastive ones."""
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        parser.add_argument(
            "--contrastive-lambda",
            type=float,
            default=0.0,
            help="The contrastive loss weight",
        )
        parser.add_argument(
            "--temperature",
            type=float,
            default=1.0,
            help="Temperature dividing the cosine similarities of the contrastive loss",
        )

    def swap_sample(self, sample):
        """Build a sample in which the target and source sentences swap roles.

        The original ``target`` becomes ``net_input["src_tokens"]``, with
        lengths counted over non-padding tokens. The original source, prefixed
        with the first column of ``prev_output_tokens``, provides the new
        ``prev_output_tokens`` (all but the last position) and the new
        ``target`` (all but the first position). Within this class only
        ``net_input["src_tokens"]`` and ``net_input["src_lengths"]`` of the
        result are consumed.
        """
        target = sample["target"]
        prev_output_tokens = sample["net_input"]["prev_output_tokens"]
        # Prepend the decoder's first input token (presumably BOS/EOS -- confirm)
        # so the source can be split into teacher-forcing input and target.
        src_tokens = torch.cat(
            (prev_output_tokens[:, :1], sample["net_input"]["src_tokens"]), dim=-1
        )
        return {
            "net_input": {
                "src_tokens": target.contiguous(),
                "src_lengths": (target != self.padding_idx).int().sum(dim=1),
                "prev_output_tokens": src_tokens[:, :-1].contiguous(),
            },
            "nsentences": sample["nsentences"],
            "ntokens": utils.item(
                (src_tokens[:, 1:] != self.padding_idx).int().sum().data
            ),
            "target": src_tokens[:, 1:].contiguous(),
            "id": sample["id"],
        }

    def forward(self, model, sample, reduce=True):
        """Compute the training loss for ``sample``.

        Returns:
            ``(loss, sample_size, logging_output)``.
            - ``loss`` is the label-smoothed cross entropy. When
              ``contrastive_lambda`` is non-zero, it also includes the
              contrastive loss scaled by ``contrastive_lambda * ntokens / nsentences``.
            - ``logging_output["loss"]`` holds only the cross-entropy part.
            - ``logging_output["contrastive_loss"]`` holds the unweighted
              contrastive loss, or ``0`` when the term is disabled.
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        nsentences = sample["target"].size(0)
        ntokens = sample["ntokens"]
        sample_size = nsentences if self.sentence_avg else ntokens

        if self.contrastive_lambda != 0:
            net_input = sample["net_input"]
            # NOTE(review): this re-encodes the source already encoded by
            # `model(...)` above. If the encoder uses dropout, the two passes
            # see different masks.
            encoder_out = model.encoder(
                net_input["src_tokens"], net_input["src_lengths"]
            ).encoder_out
            reverse_sample = self.swap_sample(sample)
            reverse_input = reverse_sample["net_input"]
            reversed_encoder_out = model.encoder(
                reverse_input["src_tokens"], reverse_input["src_lengths"]
            ).encoder_out
            contrastive_loss = self.get_contrastive_loss(
                encoder_out, reversed_encoder_out, sample, reverse_sample
            )
            # The contrastive loss is a sum over sentences. Rescaling by the
            # average sentence length presumably puts it on the same scale as
            # the token-summed cross entropy -- confirm.
            all_loss = (
                loss + contrastive_loss * self.contrastive_lambda * ntokens / nsentences
            )
            logged_contrastive = utils.item(contrastive_loss.data)
        else:
            # A zero weight contributes nothing to the loss, so skip the two
            # extra encoder passes. This also avoids NaN from 0 * inf.
            all_loss = loss
            logged_contrastive = 0

        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
            "contrastive_loss": logged_contrastive,
        }
        # Honour --report-accuracy, which is accepted in __init__ and passed to
        # the parent; without these keys the parent cannot log accuracy.
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return all_loss, sample_size, logging_output

    def similarity_function(self):
        """Return the pairwise similarity module (cosine over the last dim)."""
        return nn.CosineSimilarity(dim=-1)

    def get_contrastive_loss(self, encoder_out1, encoder_out2, sample1, sample2):
        """InfoNCE loss between mean-pooled sentence embeddings of two encodings.

        Args:
            encoder_out1, encoder_out2: encoder states. They are transposed
                with ``transpose(0, 1)`` before pooling, so presumably
                ``[src_len, batch, hidden]`` -- confirm.
            sample1, sample2: samples whose ``net_input["src_tokens"]``
                produced the corresponding states. Their non-padding
                positions form the pooling mask.

        Returns:
            A scalar tensor: the negative log-probability of each anchor's
            positive, summed over the batch. Anchor ``j`` is embedding ``j``
            from ``encoder_out1``. Its candidates are all embeddings from
            ``encoder_out2``, and its positive is embedding ``j``.
        """

        def _sentence_embedding(encoder_out, sample):
            encoder_output = encoder_out.transpose(0, 1)
            src_tokens = sample["net_input"]["src_tokens"]
            mask = src_tokens != self.padding_idx
            # Mean over non-padding positions -> [batch, hidden_size].
            return (encoder_output * mask.unsqueeze(-1)).sum(dim=1) / mask.float().sum(
                dim=1
            ).unsqueeze(-1)

        anchor_feature = _sentence_embedding(encoder_out1, sample1)
        contrast_feature = _sentence_embedding(encoder_out2, sample2)
        batch_size = contrast_feature.shape[0]
        feature_dim = contrast_feature.shape[1]
        pair_shape = (batch_size, batch_size, feature_dim)

        similarity_function = self.similarity_function()
        # anchor_dot_contrast[i, j] = sim(anchor_j, contrast_i).
        anchor_dot_contrast = similarity_function(
            anchor_feature.expand(pair_shape),
            torch.transpose(contrast_feature.expand(pair_shape), 0, 1),
        )
        # The log-softmax over dim 0 makes column j a distribution over
        # contrast candidates for anchor j; the diagonal holds the positives.
        log_probs = nn.LogSoftmax(0)(torch.div(anchor_dot_contrast, self.temperature))
        return -log_probs.diag().sum()

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs and also log ``contrastive_loss``.

        The summed contrastive loss is divided by the total number of
        sentences and by ``math.log(2)``, and is weighted by ``nsentences``.
        """
        super().reduce_metrics(logging_outputs)
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        contrastive_loss = utils.item(
            sum(log.get("contrastive_loss", 0) for log in logging_outputs)
        )
        metrics.log_scalar(
            "contrastive_loss",
            contrastive_loss / nsentences / math.log(2),
            nsentences,
            round=3,
        )