|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import string |
|
from dataclasses import dataclass, field |
|
from collections import OrderedDict |
|
from typing import Optional |
|
|
|
import torch |
|
from fairseq import metrics, utils |
|
from fairseq.criterions import FairseqCriterion, register_criterion |
|
from fairseq.dataclass import FairseqDataclass |
|
from omegaconf import II |
|
|
|
from data import data_utils |
|
from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD |
|
|
|
|
|
|
|
def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True, ce=False):
    """Compute the (optionally reward-weighted) negative log-likelihood of ``target``.

    The log-probability of every target token is gathered from ``lprobs`` and
    negated. Unless ``ce`` is set, each token loss is then scaled by
    ``reward``.

    Args:
        lprobs: log-probabilities whose last dimension indexes the vocabulary;
            the leading dimensions must match ``target``.
        target: integer tensor of token indices.
        reward: a Python number applied to every token, or a tensor with one
            entry per row of ``target`` (it is broadcast over the last
            dimension of ``target``). Ignored when ``ce`` is true.
        ignore_index: token index whose positions contribute zero loss and are
            excluded from the token count; ``None`` disables masking.
        reduce: when true, return the summed loss instead of per-token values.
        ce: when true, skip the reward weighting entirely.

    Returns:
        ``(loss, ntokens)``. ``ntokens`` counts the non-ignored target
        positions (a tensor when masking is active, otherwise an ``int``).
    """
    # squeeze(-1) removes only the gathered vocabulary axis. A bare squeeze()
    # would also drop a batch or time axis of size 1, which breaks the
    # broadcasting against ``reward`` and against the padding mask below.
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)

    if ce:
        loss = nll
    elif isinstance(reward, (int, float)):
        # Scalar reward: the same weight for every token.
        loss = nll * reward
    else:
        # Per-sequence reward: add a trailing axis so it spans the tokens.
        loss = nll * reward.unsqueeze(-1)

    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        loss.masked_fill_(pad_mask, 0.0)
        ntokens = (~pad_mask).sum()
    else:
        ntokens = target.numel()

    if reduce:
        loss = loss.sum()
    return loss, ntokens
|
|
|
|
|
@dataclass
class RefCOCOScstRewardCriterionConfig(FairseqDataclass):
    """Configuration fields for ``refcoco_scst_reward_criterion``.

    Every field here has a same-named parameter on the criterion's
    constructor.
    """

    # Path to the cached token statistics used for CIDEr scoring.
    scst_cider_cached_tokens: Optional[str] = field(
        default="coco-train-words.p",
        metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
    )
    # Number of leading target positions excluded from the loss.
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )
    # Interpolated from the optimization config section.
    sentence_avg: bool = II("optimization.sentence_avg")
    # "start,end" pair of vocabulary indices; the criterion parses it into two ints.
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"}
    )

    # IoU threshold for the 'acc' metric; None makes the raw IoU the score.
    acc_thresh: Optional[float] = field(
        default=None, metadata={"help": "acc thresh for refcoco"}
    )
    # Name of the reward metric.
    metric: Optional[str] = field(
        default='acc',
        metadata={"help": "metric"}
    )

    # Area threshold applied to the reference boxes (see medium_area).
    max_area_size: Optional[float] = field(
        default=None, metadata={"help": "max_area_size"}
    )

    # Area threshold applied to the reference boxes (see medium_area).
    min_area_size: Optional[float] = field(
        default=None, metadata={"help": "min_area_size"}
    )
    logprob: Optional[bool] = field(
        default=False, metadata={"help": "maximise log prob"}
    )

    # Value substituted for every strictly positive score.
    pos_reward: Optional[float] = field(
        default=None, metadata={"help": "pos_reward"}
    )

    # Value substituted for every score equal to zero.
    neg_reward: Optional[float] = field(
        default=None, metadata={"help": "neg_reward"}
    )

    reinforce: Optional[bool] = field(
        default=False, metadata={"help": "reinforce"}
    )

    # When > 0, a plain NLL term is added and the SCST loss is scaled by this.
    lambda_reinforce: Optional[float] = field(
        default=0, metadata={"help": "lambda_reinforce"}
    )

    # Flips the area filter to the complementary range of reference-box areas.
    # The help string used to be a copy-paste of the "reinforce" option.
    medium_area: Optional[bool] = field(
        default=False,
        metadata={"help": "flip the min/max area filter to the complementary range"},
    )
|
|
|
@register_criterion( |
|
"refcoco_scst_reward_criterion", dataclass=RefCOCOScstRewardCriterionConfig |
|
) |
|
class RefCOCOScstRewardCriterion(FairseqCriterion): |
|
CIDER_REWARD_WEIGHT = 1 |
|
|
|
def __init__(
    self,
    task,
    scst_cider_cached_tokens,
    sentence_avg,
    ignore_prefix_size=0,
    constraint_range=None,
    acc_thresh=None,
    metric='acc',
    max_area_size=None,
    min_area_size=None,
    logprob=False,
    pos_reward=None,
    neg_reward=None,
    reinforce=False,
    lambda_reinforce=0,
    medium_area=False,
):
    """Store the criterion options on the instance.

    Args:
        task: passed through to the base-class constructor.
        scst_cider_cached_tokens: accepted but not used by this constructor.
        sentence_avg: when true, ``forward`` uses the sentence count as the
            sample size, otherwise the token count.
        ignore_prefix_size: number of leading target positions to drop
            before the loss is computed.
        constraint_range: optional ``"start,end"`` string; both parts are
            parsed with ``int`` into ``constraint_start`` / ``constraint_end``.
        acc_thresh: IoU threshold handed to ``_calculate_ap_score``.
        metric: name of the reward metric.
        max_area_size, min_area_size, medium_area: area-filter options
            handed to ``_calculate_ap_score``.
        logprob, reinforce: stored on the instance only.
        pos_reward, neg_reward: optional replacement values for positive /
            zero scores.
        lambda_reinforce: when > 0, ``compute_loss`` adds a plain NLL term
            and scales the SCST loss by this factor.

    Raises:
        ValueError: if ``constraint_range`` is not two comma-separated ints.
    """
    super().__init__(task)
    self.sentence_avg = sentence_avg
    self.ignore_prefix_size = ignore_prefix_size
    # Translation table that deletes every ASCII punctuation character.
    self.transtab = str.maketrans({key: None for key in string.punctuation})

    self.constraint_start = None
    self.constraint_end = None
    if constraint_range is not None:
        constraint_start, constraint_end = constraint_range.split(',')
        self.constraint_start = int(constraint_start)
        self.constraint_end = int(constraint_end)

    # The previous debug print and the duplicate assignment of ``metric``
    # were removed; the stored value is unchanged.
    self.metric = metric
    self.acc_thresh = acc_thresh
    self.min_area_size = min_area_size
    self.max_area_size = max_area_size
    self.logprob = logprob

    self.pos_reward = pos_reward
    self.neg_reward = neg_reward

    self.reinforce = reinforce
    self.lambda_reinforce = lambda_reinforce

    self.medium_area = medium_area
|
|
|
|
|
|
|
|
|
def forward(self, model, sample, update_num=0, reduce=True):
    """Run the criterion on one batch.

    The work is delegated to ``compute_loss``; this method only chooses the
    sample size and assembles the logging dictionary. ``update_num`` is
    accepted but not used.

    Returns:
        A 3-tuple ``(loss, sample_size, logging_output)`` where
        ``sample_size`` is the sentence count when ``self.sentence_avg`` is
        set and the token count otherwise.
    """
    loss, score, ntokens, nsentences = self.compute_loss(
        model, sample, reduce=reduce
    )

    if self.sentence_avg:
        sample_size = nsentences
    else:
        sample_size = ntokens

    logging_output = dict(
        loss=loss.data,
        score=score,
        ntokens=ntokens,
        nsentences=nsentences,
        sample_size=sample_size,
    )
    return loss, sample_size, logging_output
|
|
|
def _calculate_eval_scores(self, gen_res, gt_idx, gt_res): |
|
''' |
|
gen_res: generated captions, list of str |
|
gt_idx: list of int, of the same length as gen_res |
|
gt_res: ground truth captions, list of list of str. |
|
gen_res[i] corresponds to gt_res[gt_idx[i]] |
|
Each image can have multiple ground truth captions |
|
''' |
|
|
|
gen_res_size = len(gen_res) |
|
|
|
res = OrderedDict() |
|
for i in range(gen_res_size): |
|
res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))] |
|
|
|
gts = OrderedDict() |
|
gt_res_ = [ |
|
[self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))] |
|
for i in range(len(gt_res)) |
|
] |
|
for i in range(gen_res_size): |
|
gts[i] = gt_res_[gt_idx[i]] |
|
|
|
res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))] |
|
|
|
|
|
if self.metric != 'cider': |
|
predicts = [res[i][0] if isinstance(res[i], list) else res[i] for i in range(len(res))] |
|
|
|
answers = [gts[i] for i in range(gen_res_size)] |
|
|
|
results = self.evaluator.run_evaluation(predicts, answers) |
|
batch_cider_scores = results[self.metric] |
|
|
|
batch_cider_scores = torch.tensor(batch_cider_scores).repeat(gen_res_size) |
|
else: |
|
_, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_) |
|
|
|
scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores |
|
return scores |
|
|
|
@classmethod |
|
def _wrap_sentence(self, s): |
|
|
|
|
|
r = s.strip() |
|
if r.endswith('.'): |
|
r = r[:-1] |
|
r += ' <eos>' |
|
return r |
|
|
|
|
|
def get_generator_out(self, model, sample):
    """Generate with the task's SCST generator and collect the top hypotheses.

    ``model`` and the generator's own model are switched to eval mode and
    generation runs without gradient tracking.

    Returns:
        ``(gen_target, gen_res, gt_res)``, three lists with one entry per
        generated sample:
        * ``gen_target`` - first hypothesis' tokens without the final
          token, as int tensors on the CPU;
        * ``gen_res`` - the same tokens shifted by
          ``- len(src_dict) + num_bins``;
        * ``gt_res`` - ``sample["target"]`` rows, last token dropped, with
          the same shift.
        NOTE(review): the shift presumably converts location-token ids into
        coordinate bin indices - confirm against the task's vocabulary layout.
    """
    model.eval()
    generator = self.task.scst_generator
    with torch.no_grad():
        generator.model.eval()
        hypotheses = generator.generate([model], sample)

    gen_target, gen_res, gt_res = [], [], []
    for row, candidates in enumerate(hypotheses):
        vocab_size = len(self.task.src_dict)
        num_bins = self.task.cfg.num_bins
        # Only the first hypothesis is used; its last token is dropped.
        best = candidates[0]["tokens"][:-1]
        gen_res.append(best - vocab_size + num_bins)
        gt_res.append(sample["target"][row][:-1] - vocab_size + num_bins)
        gen_target.append(best.int().cpu())

    return gen_target, gen_res, gt_res
|
|
|
def _calculate_ap_score(self, hyps, refs, thresh=0.5, min_area_size=None, max_area_size=None, medium_area=False): |
|
interacts = torch.cat( |
|
[torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]), |
|
torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])], |
|
dim=1 |
|
) |
|
area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1]) |
|
area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1]) |
|
interacts_w = interacts[:, 2] - interacts[:, 0] |
|
interacts_h = interacts[:, 3] - interacts[:, 1] |
|
area_interacts = interacts_w * interacts_h |
|
ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6) |
|
|
|
|
|
if max_area_size is not None and min_area_size is not None: |
|
if medium_area: |
|
ious = ious * (torch.logical_and(area_targets > max_area_size, area_targets < min_area_size).float()) |
|
|
|
else: |
|
ious = ious * (torch.logical_or(area_targets < max_area_size, area_targets > min_area_size).float()) |
|
|
|
elif min_area_size is not None: |
|
if medium_area: |
|
ious = ious * (area_targets < min_area_size).float() |
|
else: |
|
ious = ious * (area_targets > min_area_size).float() |
|
|
|
elif max_area_size is not None: |
|
if medium_area: |
|
ious = ious * (area_targets > max_area_size).float() |
|
else: |
|
ious = ious * (area_targets < max_area_size).float() |
|
|
|
if thresh is None: |
|
return ious |
|
else: |
|
return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float() |
|
|
|
|
|
def get_reward_and_scores(self, gen_res, gt_res, device, sample):
    """Turn generated box tokens into per-sample rewards.

    The stacked ``gen_res`` values are scaled by
    ``/ (num_bins - 1) * max_image_size`` and then divided by the per-sample
    resize ratios (even columns by ``w_resize_ratios``, odd columns by
    ``h_resize_ratios``). The result is scored against
    ``sample['region_coords']`` with ``_calculate_ap_score``.

    Args:
        gen_res: list of equally sized 1-D tensors, one per sample.
        gt_res: accepted for interface compatibility; the reference boxes
            are taken from ``sample['region_coords']`` instead.
        device: accepted for interface compatibility; not used.
        sample: batch dict providing ``w_resize_ratios``,
            ``h_resize_ratios`` and ``region_coords``.

    Returns:
        ``(reward, scores)`` - both are the same tensor.

    Raises:
        NotImplementedError: if ``self.metric`` is not ``'acc'``.
    """
    if self.metric != 'acc':
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it fails with an unrelated TypeError.
        raise NotImplementedError(
            "unsupported reward metric: {!r}".format(self.metric)
        )

    cfg = self.task.cfg
    hyps = torch.stack(gen_res, dim=0) / (cfg.num_bins - 1) * cfg.max_image_size
    hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
    hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)

    scores = self._calculate_ap_score(
        hyps,
        sample['region_coords'].float(),
        thresh=self.acc_thresh,
        min_area_size=self.min_area_size,
        max_area_size=self.max_area_size,
        medium_area=self.medium_area,
    )

    # Optional reward shaping, applied only when the option is set to a
    # non-zero value. masked_fill takes a plain Python number as fill value.
    if self.pos_reward:
        scores = scores.masked_fill(scores > 0, self.pos_reward)
    if self.neg_reward:
        scores = scores.masked_fill(scores == 0, self.neg_reward)

    return scores, scores
|
|
|
|
|
def get_net_output(self, model, sample, gen_target):
    """Run ``model`` in training mode on the generated sequences.

    The source-side inputs are repeated ``len(gen_target) // batch_size``
    times along dim 0, and ``gen_target`` is collated twice: once as decoder
    input (BOS index passed as ``eos_idx`` with ``move_eos_to_beginning``
    enabled) and once as the target.

    Returns:
        ``(net_output, gen_target_tokens)`` - the model output and the
        collated int64 targets on the device of ``sample["target"]``.

    NOTE(review): padding and shifting are done by
    ``data_utils.collate_tokens``; confirm its exact semantics there.
    """
    dictionary = self.task.tgt_dict
    eos_index = dictionary.eos()

    def collate(sequences, eos_idx, shift):
        return data_utils.collate_tokens(
            sequences,
            pad_idx=self.padding_idx,
            eos_idx=eos_idx,
            left_pad=False,
            move_eos_to_beginning=shift,
        )

    seq_per_img = len(gen_target) // len(sample["target"])

    model.train()

    net_input = sample['net_input']
    repeated = {
        key: torch.repeat_interleave(net_input[key], seq_per_img, dim=0)
        for key in ('src_tokens', 'src_lengths', 'patch_images', 'patch_masks')
    }

    target_device = sample["target"].device
    gen_prev_output_tokens = torch.as_tensor(
        collate(gen_target, dictionary.bos(), True),
        device=target_device,
        dtype=torch.int64,
    )
    gen_target_tokens = torch.as_tensor(
        collate(gen_target, eos_index, False),
        device=target_device,
        dtype=torch.int64,
    )

    net_output = model(
        src_tokens=repeated['src_tokens'],
        src_lengths=repeated['src_lengths'],
        patch_images=repeated['patch_images'],
        patch_masks=repeated['patch_masks'],
        prev_output_tokens=gen_prev_output_tokens,
    )
    return net_output, gen_target_tokens
|
|
|
def get_lprobs_and_target(self, model, net_output, gen_target):
    """Return log-probabilities and targets with the ignored prefix removed.

    When both constraint bounds are set, the logits in ``net_output[0]`` are
    overwritten in place with ``-inf`` for vocabulary indices
    ``4:constraint_start`` and ``constraint_end:`` before normalisation.

    Args:
        model: object providing ``get_normalized_probs``.
        net_output: model output; element 0 holds the logits with the
            vocabulary on the last axis.
        gen_target: targets laid out as (batch, time).

    Returns:
        ``(lprobs, gen_target)`` with the first ``ignore_prefix_size`` time
        steps dropped from both.
    """
    if self.constraint_start is not None and self.constraint_end is not None:
        net_output[0][:, :, 4:self.constraint_start] = -math.inf
        net_output[0][:, :, self.constraint_end:] = -math.inf
    lprobs = model.get_normalized_probs(net_output, log_probs=True)
    if self.ignore_prefix_size > 0:
        # Always slice the time axis (dim 1). The old ``batch_first`` lookup
        # on the tensor never succeeded, so the fallback branch ran and
        # dropped whole sequences (dim 0) instead of prefix tokens, unlike
        # the dim-1 slice that compute_loss applies to its targets.
        lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
        gen_target = gen_target[:, self.ignore_prefix_size:].contiguous()
    return lprobs, gen_target
|
|
|
def compute_loss(self, model, sample, reduce=True):
    """Compute the SCST loss for one batch.

    Steps: generate sequences, score them, re-run the model on the generated
    sequences, and weight the token log-probabilities by the reward via
    ``scst_loss``. When ``self.lambda_reinforce`` is positive, an unweighted
    NLL term over the model targets (last position dropped) is added and the
    SCST loss is scaled by ``lambda_reinforce``.

    Returns:
        ``(loss, score_sum, ntokens, nsentences)``.
    """
    sampled_tokens, hyp_bins, ref_bins = self.get_generator_out(model, sample)
    reward, scores = self.get_reward_and_scores(
        hyp_bins, ref_bins, device=sample["target"].device, sample=sample
    )

    net_output, padded_targets = self.get_net_output(model, sample, sampled_tokens)
    gen_lprobs, padded_targets = self.get_lprobs_and_target(
        model, net_output, padded_targets
    )

    loss, ntokens = scst_loss(
        gen_lprobs,
        padded_targets,
        reward,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )
    nsentences = padded_targets.size(0)

    if self.lambda_reinforce > 0:
        gold = model.get_targets(sample, net_output)[:, :-1]
        if self.ignore_prefix_size > 0:
            gold = gold[:, self.ignore_prefix_size:].contiguous()

        # ce=True disables the reward weighting; the token count is not needed.
        nll_term, _ = scst_loss(
            gen_lprobs,
            gold,
            reward=1,
            ignore_index=self.padding_idx,
            reduce=reduce,
            ce=True,
        )
        loss = nll_term + self.lambda_reinforce * loss

    return loss, scores.sum(), ntokens, nsentences
|
|
|
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
    """Sum the per-worker logging outputs and log the aggregated scalars.

    ``loss`` is logged as the summed loss divided by the summed sample size
    and ``score`` as the summed score divided by the sentence count; the raw
    totals of ``ntokens``, ``nsentences`` and ``sample_size`` are logged too.
    """
    def total(key):
        return sum(log.get(key, 0) for log in logging_outputs)

    loss_sum = total("loss")
    score_sum = total("score")
    ntokens = total("ntokens")
    nsentences = total("nsentences")
    sample_size = total("sample_size")

    metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
    metrics.log_scalar("score", score_sum / nsentences, nsentences, round=3)

    for name, count in (
        ("ntokens", ntokens),
        ("nsentences", nsentences),
        ("sample_size", sample_size),
    ):
        metrics.log_scalar(name, count, 1, round=3)
|
|
|
@staticmethod |
|
def logging_outputs_can_be_summed() -> bool: |
|
""" |
|
Whether the logging outputs returned by `forward` can be summed |
|
across workers prior to calling `reduce_metrics`. Setting this |
|
to True will improves distributed training speed. |
|
""" |
|
return True |
|
|