|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from fairseq import utils |
|
from fairseq.criterions import LegacyFairseqCriterion, register_criterion |
|
from fairseq.data import encoders |
|
|
|
|
|
@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
    """Criterion for the Winograd Schema Challenge (WSC) task.

    Scores the masked query span against candidate spans using the
    model's average token log-likelihood, and trains with either a
    margin ranking loss or a cross-entropy formulation (selected by
    ``--wsc-cross-entropy``).
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, "w")
        else:
            self.prediction_h = None
        self.bpe = encoders.build_bpe(args.bpe)
        self.tokenizer = encoders.build_tokenizer(args.tokenizer)

    def __del__(self):
        # Best-effort close of the predictions file opened in __init__.
        if self.prediction_h is not None:
            self.prediction_h.close()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
        parser.add_argument(
            "--wsc-cross-entropy",
            action="store_true",
            help="use cross entropy formulation instead of margin loss",
        )
        parser.add_argument(
            "--save-predictions", metavar="FILE", help="file to save predictions to"
        )

    def get_masked_input(self, tokens, mask):
        """Return a copy of *tokens* with masked positions set to the task's mask token."""
        masked_tokens = tokens.clone()
        masked_tokens[mask] = self.task.mask
        return masked_tokens

    def get_lprobs(self, model, tokens, mask):
        """Return the per-sequence mean log-probability of the tokens under *mask*.

        The input is masked before being fed to the model, and the model's
        log-probabilities for the original tokens at the masked positions
        are averaged over the number of masked positions.
        """
        logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
        scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
        mask = mask.type_as(scores)
        # Mean over masked positions only; assumes each mask row has at
        # least one set position (otherwise this divides by zero) -- TODO
        # confirm upstream data guarantees this.
        scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
        return scores

    def get_loss(self, query_lprobs, cand_lprobs):
        """Combine query and candidate scores into a scalar training loss."""
        if self.args.wsc_cross_entropy:
            # Cross-entropy formulation: treat the query as class 0 among
            # the concatenated [query, candidates] scores.
            return F.cross_entropy(
                torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
                query_lprobs.new([0]).long(),
            )
        else:
            # Margin ranking loss: maximize the query score while pushing
            # it above each candidate score by at least ``beta``, with the
            # margin term scaled by ``alpha``.
            return (
                -query_lprobs
                + self.args.wsc_margin_alpha
                * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
            ).sum()

    def forward(self, model, sample, reduce=True):
        """Compute the loss for *sample*.

        Returns a ``(loss, sample_size, logging_output)`` tuple, where
        ``sample_size`` is the number of labelled queries (or 1 when the
        batch has none).
        """
        loss, nloss = 0.0, 0
        ncorrect, nqueries = 0, 0

        for i, label in enumerate(sample["labels"]):
            query_lprobs = self.get_lprobs(
                model,
                sample["query_tokens"][i].unsqueeze(0),
                sample["query_masks"][i].unsqueeze(0),
            )
            cand_lprobs = self.get_lprobs(
                model,
                sample["candidate_tokens"][i],
                sample["candidate_masks"][i],
            )

            # Predict positive iff the query outscores every candidate.
            pred = (query_lprobs >= cand_lprobs).all().item()

            if label is not None:
                label = 1 if label else 0
                ncorrect += 1 if pred == label else 0
                nqueries += 1

                if label:
                    # Only positive examples contribute to the training loss.
                    nloss += 1
                    loss += self.get_loss(query_lprobs, cand_lprobs)

            # Renamed from ``id`` to avoid shadowing the builtin.
            sample_id = sample["id"][i].item()
            if self.prediction_h is not None:
                print("{}\t{}\t{}".format(sample_id, pred, label), file=self.prediction_h)

        if nloss == 0:
            # No positive labelled examples in this batch: emit a zero loss
            # that still participates in autograd.
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        agg_output = {
            # Report the loss in base 2 (bits per sample).
            "loss": loss_sum / sample_size / math.log(2),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output["accuracy"] = ncorrect / float(nqueries)

        return agg_output
|
|
|
|
|
@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
    """Winogrande variant of the WSC criterion.

    Unlike the WSC parent, every example is scored in a single batched
    pass and the query span is always the correct completion, so a
    prediction counts as correct whenever the query score is at least
    the candidate score.
    """

    def forward(self, model, sample, reduce=True):
        query_scores = self.get_lprobs(
            model, sample["query_tokens"], sample["query_masks"]
        )
        candidate_scores = self.get_lprobs(
            model, sample["candidate_tokens"], sample["candidate_masks"]
        )

        loss = self.get_loss(query_scores, candidate_scores)
        predictions = query_scores >= candidate_scores

        n_examples = sample["query_tokens"].size(0)
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": n_examples,
            "ncorrect": predictions.sum().item(),
            "nqueries": n_examples,
        }
        return loss, n_examples, logging_output
|
|