# Source: UnIVAL/utils/eval_utils.py (last commit 26fd00c "init" by mshukor, 17.9 kB).
# Modified from OFA code.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import string
import math
import json
from itertools import chain
import os
import torch
import torch.distributed as dist
from data import data_utils
from functools import partial
def get_symbols_to_strip_from_output(generator):
    """Return the symbols that must be removed from a decoded hypothesis.

    Prefers the generator's own ``symbols_to_strip_from_output`` attribute;
    when the generator does not define one, falls back to the set made of its
    ``bos`` and ``eos`` symbols.
    """
    if not hasattr(generator, "symbols_to_strip_from_output"):
        return {generator.bos, generator.eos}
    return generator.symbols_to_strip_from_output
def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
    """Turn a tensor of token ids into a detokenized string.

    Steps, in order:
      1. ``tgt_dict.string`` on the ids (cast to int, moved to CPU), ignoring
         the symbols reported by ``get_symbols_to_strip_from_output``;
      2. ``bpe.decode`` if a BPE object is given;
      3. ``tokenizer.decode`` if a tokenizer is given.
    """
    token_ids = x.int().cpu()
    text = tgt_dict.string(token_ids, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator))
    # BPE is undone before the (optional) word-level tokenizer.
    for post_processor in (bpe, tokenizer):
        if post_processor is not None:
            text = post_processor.decode(text)
    return text
def eval_caption(task, generator, models, sample, **kwargs):
    """Generate one caption per sample and return it with punctuation removed.

    Returns:
        ``(results, None)``: ``results`` is a list of
        ``{"image_id": str, "caption": str}`` dicts, one per entry of
        ``sample["id"]``, built from the top hypothesis of each sample.
    """
    # Translation table that deletes every ASCII punctuation character.
    strip_punctuation = str.maketrans(dict.fromkeys(string.punctuation))
    hypos = task.inference_step(generator, models, sample)

    def _caption_for(index):
        text = decode_fn(hypos[index][0]["tokens"], task.tgt_dict, task.bpe, generator)
        return text.translate(strip_punctuation).strip()

    results = [
        {"image_id": str(sample_id), "caption": _caption_for(index)}
        for index, sample_id in enumerate(sample["id"].tolist())
    ]
    return results, None
def eval_vqa_gen(task, generator, models, sample, **kwargs):
    """Evaluate one VQA batch, either by generation or by ranking candidates.

    The mode is selected by the ``beam_search_vqa_eval`` keyword (default
    False when omitted):

    * truthy: run ``task.inference_step`` with ``sample['prefix_tokens']`` as
      a forced decoder prefix, drop that prefix from the top hypothesis and
      detokenize the rest as the answer.
    * falsy: score every candidate in ``task.valid_answers_list`` by the sum
      of the decoder's log-probabilities over the answer tokens plus EOS,
      under the per-position masks in ``task.valid_constraint_masks_list``,
      and pick the best candidate via ``task.index2ans``.

    Args:
        task: supplies dictionaries, BPE, candidate lists and ``index2ans``.
        generator: sequence generator, only used in beam-search mode.
        models: list of models; ranking mode only uses ``models[0]``.
        sample: collated batch with ``id``, ``ref_dict`` and either
            ``prefix_tokens`` (beam mode) or ``net_input`` and
            ``decoder_prompts`` (ranking mode).
        **kwargs: ``beam_search_vqa_eval`` selects the mode.

    Returns:
        ``(results, scores)``: ``results`` is a list of
        ``{"question_id": int, "answer": str}`` and ``scores[i]`` is
        ``sample['ref_dict'][i].get(answer_i, 0)``.
    """
    pad = task.src_dict.pad()

    if kwargs.get('beam_search_vqa_eval', False):
        hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'])
        results = []
        for i, sample_id in enumerate(sample["id"].tolist()):
            # Number of non-pad prefix tokens to skip. This used to be a
            # hard-coded 1 (fairseq's default pad index); it now assumes the
            # prefix was collated with the same pad index used below.
            prefix_len = sample['prefix_tokens'][i].ne(pad).sum().item()
            detok_hypo_str = decode_fn(hypos[i][0]["tokens"][prefix_len:], task.tgt_dict, task.bpe, generator)
            results.append({"question_id": int(sample_id), "answer": detok_hypo_str.strip()})
        scores = [ref_dict.get(result['answer'], 0) for ref_dict, result in zip(sample['ref_dict'], results)]
        return results, scores

    encoder_out = models[0].encoder(
        sample["net_input"]["src_tokens"],
        src_lengths=sample["net_input"]["src_lengths"],
        patch_images=sample["net_input"]["patch_images"],
        patch_masks=sample["net_input"]["patch_masks"]
    )
    device = sample["net_input"]["src_tokens"].device
    eos_item = torch.tensor([task.src_dict.eos()])
    valid_result = []
    for valid_answers, valid_constraint_masks in zip(task.valid_answers_list, task.valid_constraint_masks_list):
        valid_size = len(valid_answers)
        # Teacher forcing: the decoder input is prompt + answer, and the target
        # is the same sequence shifted left by one with EOS appended.
        valid_tgt_items = [
            torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
            for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
        ]
        valid_prev_items = [
            torch.cat([torch.tensor(decoder_prompt), valid_answer])
            for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
        ]
        # Prompt positions get an all-False mask row; the answer's own
        # constraint mask follows.
        valid_constraint_mask_items = [
            torch.cat(
                [torch.zeros(len(decoder_prompt) - 1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask],
                dim=0
            )
            for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
        ]
        valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad).to(device)
        valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad).to(device)
        valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad).to(device)

        # Repeat each sample's encoder state once per candidate answer. Dim 1
        # is used for "encoder_out" (presumably time-major, T x B x C) and
        # dim 0 for the other entries.
        new_encoder_out = {}
        new_encoder_out["encoder_out"] = [
            encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
        ]
        new_encoder_out["encoder_padding_mask"] = [
            encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
        ]
        new_encoder_out["position_embeddings"] = [
            encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
        ]

        decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
        decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
        lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
        scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
        scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
        # Positions where every token was masked (the prompt) yield non-finite
        # log-probs; zero them so only answer positions contribute.
        scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
        scores = scores.sum(1)
        scores = scores.view(-1, valid_size)
        valid_result.append(scores)
    valid_result = torch.cat(valid_result, dim=-1)
    predicts = valid_result.argmax(1).tolist()
    hyps = [task.index2ans[predict_index] for predict_index in predicts]
    results = [{"question_id": int(qid), "answer": hyp} for qid, hyp in zip(sample["id"].tolist(), hyps)]
    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
    return results, scores
def _calculate_ap_score(hyps, refs, thresh=0.5, min_area_size=None, max_area_size=None):
    """Per-row box-matching score between predicted and reference boxes.

    Rows of ``hyps`` and ``refs`` are boxes read as ``(x0, y0, x1, y1)``
    (columns 0-3; assumes exactly 4 columns).

    Area filtering uses the *reference* box area, and filtered-out rows score 0:
      * only ``min_area_size``: keep area > min_area_size ("large" bucket);
      * only ``max_area_size``: keep area < max_area_size ("small" bucket);
      * both: keep max_area_size < area < min_area_size ("medium" bucket).
        ``min_area_size`` is the lower bound of the large bucket and
        ``max_area_size`` the upper bound of the small one, so callers are
        expected to pass max_area_size < min_area_size.

    Returns:
        If ``thresh`` is None, the IoU per row (float tensor). Otherwise a 0/1
        float tensor: IoU >= thresh and the boxes truly overlap.
    """
    interacts = torch.cat(
        [torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
         torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
        dim=1
    )
    area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
    area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
    interacts_w = interacts[:, 2] - interacts[:, 0]
    interacts_h = interacts[:, 3] - interacts[:, 1]
    # Clamp so disjoint boxes get zero intersection; without it two negative
    # extents multiply into a positive "overlap" area.
    area_interacts = interacts_w.clamp(min=0) * interacts_h.clamp(min=0)
    ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)

    if max_area_size is not None and min_area_size is not None:
        ious = ious * (area_targets > max_area_size).float() * (area_targets < min_area_size).float()
    elif min_area_size is not None:
        ious = ious * (area_targets > min_area_size).float()
    elif max_area_size is not None:
        ious = ious * (area_targets < max_area_size).float()

    if thresh is None:
        return ious
    return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()


def _parse_acc_thresholds(acc_thresh, default=0.5):
    """Normalize the ``acc_thresh`` config value into a list of floats.

    Accepts a comma-separated string (``"0.5,0.75"``), a single number or
    numeric string, or a list/tuple of those. ``None`` yields ``[default]``
    instead of crashing the evaluation.
    """
    if acc_thresh is None:
        return [default]
    if isinstance(acc_thresh, str):
        acc_thresh = acc_thresh.split(',')
    elif not isinstance(acc_thresh, (list, tuple)):
        acc_thresh = [acc_thresh]
    return [float(t) for t in acc_thresh]


def eval_refcoco(task, generator, models, sample, **kwargs):
    """Evaluate referring-expression grounding at one or more IoU thresholds.

    The predicted box is decoded from the generated location tokens and
    mapped back to original-image coordinates. It is then compared with
    ``sample['region_coords']`` at each threshold in
    ``kwargs['evaluate_cfg'].acc_thresh`` (0.5 when unset). Extra per-bucket
    accuracies are added when ``min_area_size`` / ``max_area_size`` are set.

    Returns:
        ``([names, results], scores)``: ``results`` is a list of
        ``{"uniq_id": ..., "box": [x0, y0, x1, y1]}``; ``names[k]`` labels the
        per-sample score tensor ``scores[k]`` (e.g. ``"0.5acc"``,
        ``"0.5large_acc"``).
    """
    gen_out = task.inference_step(generator, models, sample)
    # Location tokens occupy the last ``num_bins`` vocabulary ids: shift them
    # back to bin indices in [0, num_bins) and drop the final token
    # (presumably EOS).
    hyps_ = torch.stack(
        [hypo[0]["tokens"][:-1] - len(task.src_dict) + task.cfg.num_bins for hypo in gen_out],
        dim=0,
    )
    # Bin index -> coordinate in the resized image, then undo the resize
    # (even columns are x / width, odd columns are y / height).
    hyps = hyps_ / (task.cfg.num_bins - 1) * task.cfg.max_image_size
    hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
    hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)

    results = [
        {"uniq_id": sample_id,
         "box": [hyps[i][0].item(), hyps[i][1].item(), hyps[i][2].item(), hyps[i][3].item()]}
        for i, sample_id in enumerate(sample["id"].tolist())
    ]

    evaluate_cfg = kwargs['evaluate_cfg']
    min_area_size = getattr(evaluate_cfg, 'min_area_size', None)
    max_area_size = getattr(evaluate_cfg, 'max_area_size', None)
    refs = sample['region_coords'].float()

    names = []
    scores_list = []
    for thresh in _parse_acc_thresholds(getattr(evaluate_cfg, 'acc_thresh', None)):
        names.append(str(thresh) + 'acc')
        scores_list.append(_calculate_ap_score(hyps, refs, thresh=thresh))
        if min_area_size is not None:
            scores_list.append(_calculate_ap_score(hyps, refs, thresh=thresh, min_area_size=min_area_size))
            names.append(str(thresh) + 'large_acc')
        if max_area_size is not None:
            scores_list.append(_calculate_ap_score(hyps, refs, thresh=thresh, max_area_size=max_area_size))
            names.append(str(thresh) + 'small_acc')
        if max_area_size is not None and min_area_size is not None:
            scores_list.append(_calculate_ap_score(hyps, refs, thresh=thresh,
                                                   max_area_size=max_area_size, min_area_size=min_area_size))
            names.append(str(thresh) + 'medium_acc')

    return [names, results], scores_list
def eval_snli_ve(task, generator, models, sample, **kwargs):
    """Visual-entailment evaluation by ranking the candidate label sequences.

    Every candidate in ``task.valid_answers_list`` is appended to each
    sample's decoder prompt. It is scored by the summed log-probability of its
    tokens plus EOS under ``models[0]``, with decoding restricted by the masks
    in ``task.valid_constraint_masks_list``. The argmax is mapped through
    ``task.index2ans``.

    Returns:
        ``(results, scores)``: ``results`` is a list of
        ``{"uniq_id": ..., "answer": ...}`` and ``scores[i]`` is
        ``sample['ref_dict'][i].get(answer_i, 0)``.
    """
    model = models[0]
    net_input = sample["net_input"]
    encoder_out = model.encoder(
        net_input["src_tokens"],
        src_lengths=net_input["src_lengths"],
        patch_images=net_input["patch_images"],
        patch_masks=net_input["patch_masks"]
    )
    device = net_input["src_tokens"].device
    eos_item = torch.tensor([task.src_dict.eos()])
    pad = task.src_dict.pad()
    prompts = sample["decoder_prompts"]

    candidate_scores = []
    for answers, answer_masks in zip(task.valid_answers_list, task.valid_constraint_masks_list):
        n_answers = len(answers)
        # Target = input shifted left by one with EOS appended (teacher forcing).
        tgt_items = [
            torch.cat([torch.tensor(prompt[1:]), answer, eos_item])
            for prompt in prompts for answer in answers
        ]
        prev_items = [
            torch.cat([torch.tensor(prompt), answer])
            for prompt in prompts for answer in answers
        ]
        # All-False rows for the prompt part, then the answer's own mask.
        mask_items = [
            torch.cat([torch.zeros(len(prompt) - 1, answer_mask.size(1)).bool(), answer_mask], dim=0)
            for prompt in prompts for answer_mask in answer_masks
        ]
        tgt = data_utils.collate_tokens(tgt_items, pad_idx=pad).to(device)
        prev_output = data_utils.collate_tokens(prev_items, pad_idx=pad).to(device)
        allowed = data_utils.collate_tokens(mask_items, pad_idx=pad).to(device)

        # One copy of every sample's encoder state per candidate answer.
        expanded_encoder_out = {
            key: [encoder_out[key][0].repeat_interleave(n_answers, dim=axis)]
            for key, axis in (("encoder_out", 1), ("encoder_padding_mask", 0), ("position_embeddings", 0))
        }

        dec_out = model.decoder(prev_output, encoder_out=expanded_encoder_out)
        dec_out[0].masked_fill_(~allowed, -math.inf)
        lprobs = model.get_normalized_probs(dec_out, log_probs=True)
        token_lprobs = lprobs.gather(dim=-1, index=tgt.unsqueeze(-1)).squeeze(-1)
        token_lprobs = token_lprobs.masked_fill(tgt.eq(task.tgt_dict.pad()), 0)
        # Fully-masked (prompt) positions must not contribute to the sum.
        token_lprobs = token_lprobs.masked_fill((~allowed).all(2), 0)
        candidate_scores.append(token_lprobs.sum(1).view(-1, n_answers))

    best_indices = torch.cat(candidate_scores, dim=-1).argmax(1).tolist()
    hyps = [task.index2ans[best] for best in best_indices]
    results = [{"uniq_id": uid, "answer": hyp} for uid, hyp in zip(sample["id"].tolist(), hyps)]
    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
    return results, scores
def eval_image_gen(task, generator, models, sample, **kwargs):
    """Generate images for the batch's caption and rank them by text similarity.

    Only the first sample of the batch is used to recover the caption, and its
    id is attached to every result (presumably the batch size is 1 here;
    confirm against the caller).

    Returns:
        ``(results, scores)``: ``results`` holds one
        ``{"sample_id", "score", "image"}`` dict per index returned by
        ``task.compute_text_similarity``; ``scores`` is a one-element list
        holding the best similarity score.

    Side effects:
        When ``task.cfg.gen_images_path`` is set, the images, sorted by
        similarity, are written under ``all_results/`` and the best one under
        ``top1/``.
    """
    hypos, _ = task.inference_image(generator, sample, models)
    # Rebuild the caption text from the source tokens. Ids < 4 are dropped
    # (presumably fairseq's special symbols) and the first 38 characters are
    # cut (presumably the fixed task prompt -- confirm against the dataset).
    tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
    caption = task.bpe.decode(task.tgt_dict.string([token for token in tokens if token >= 4]))[
              38:].replace('/', '')

    text_similarity_score, indices = task.compute_text_similarity(hypos, caption,
                                                                  sample['net_input']['src_tokens'].device)
    sample_id = str(sample["id"][0])
    results = [
        {"sample_id": sample_id, "score": text_similarity_score[i], "image": hypos[indice]}
        for i, indice in enumerate(indices)
    ]
    scores = [max(text_similarity_score).item()]
    sorted_hyps = [hypos[indice] for indice in indices]

    # Dump results; the caption computed above is reused as the image label.
    if task.cfg.gen_images_path:
        task.dump_images(sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'all_results'))
        task.dump_images(sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'top1'), topk=1)
    return results, scores
def eval_image_classify(task, generator, models, sample, **kwargs):
    """Classify each image by scoring every pre-built candidate label sequence.

    The candidates come from ``task.valid_tgt_list``,
    ``task.valid_prev_output_list`` and ``task.valid_constraint_masks_list``
    (already collated tensors). They are tiled over the batch, scored by the
    summed log-probability of their non-pad target tokens under ``models[0]``,
    and the argmax is mapped through ``task.index2ans``.

    Returns:
        ``(results, scores)``: ``results`` is a list of
        ``{"uniq_id": ..., "answer": ...}`` and ``scores[i]`` is
        ``sample['ref_dict'][i].get(answer_i, 0)``.
    """
    model = models[0]
    net_input = sample["net_input"]
    src_tokens = net_input["src_tokens"]
    batch_size = src_tokens.size(0)
    encoder_out = model.encoder(
        src_tokens,
        src_lengths=net_input["src_lengths"],
        patch_images=net_input["patch_images"],
        patch_masks=net_input["patch_masks"]
    )
    device = src_tokens.device

    label_scores = []
    candidate_groups = zip(task.valid_tgt_list, task.valid_prev_output_list, task.valid_constraint_masks_list)
    for cand_tgt, cand_prev, cand_masks in candidate_groups:
        n_cands = cand_tgt.size(0)
        # Tile the candidate block once per sample in the batch.
        tgt = cand_tgt.repeat(batch_size, 1).to(device)
        prev_output = cand_prev.repeat(batch_size, 1).to(device)
        allowed = cand_masks.repeat(batch_size, 1, 1).to(device)

        # Matching repetition of each sample's encoder state.
        expanded_encoder_out = {
            key: [encoder_out[key][0].repeat_interleave(n_cands, dim=axis)]
            for key, axis in (("encoder_out", 1), ("encoder_padding_mask", 0), ("position_embeddings", 0))
        }

        dec_out = model.decoder(prev_output, encoder_out=expanded_encoder_out)
        dec_out[0].masked_fill_(~allowed, -math.inf)
        lprobs = model.get_normalized_probs(dec_out, log_probs=True)
        token_lprobs = lprobs.gather(dim=-1, index=tgt.unsqueeze(-1)).squeeze(-1)
        token_lprobs = token_lprobs.masked_fill(tgt.eq(task.tgt_dict.pad()), 0)
        label_scores.append(token_lprobs.sum(1).view(-1, n_cands))

    best_indices = torch.cat(label_scores, dim=-1).argmax(1).tolist()
    hyps = [task.index2ans[best] for best in best_indices]
    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
    results = [{"uniq_id": uid, "answer": hyp} for uid, hyp in zip(sample["id"].tolist(), hyps)]
    return results, scores
def eval_step(task, generator, models, sample, **kwargs):
    """Dispatch one evaluation batch to the evaluator for ``task.cfg._name``.

    Caption and VQA-generation tasks match by substring (any task name that
    contains ``'caption'`` or ``'vqa_gen'``); the others need an exact name.
    ``image_classify`` is routed to ``eval_image_classify``, which was
    previously unreachable from here.

    Raises:
        NotImplementedError: no evaluator is registered for the task name.
    """
    task_name = task.cfg._name
    if 'caption' in task_name:
        return eval_caption(task, generator, models, sample, **kwargs)
    if 'vqa_gen' in task_name:
        return eval_vqa_gen(task, generator, models, sample, **kwargs)
    if task_name == 'refcoco':
        return eval_refcoco(task, generator, models, sample, **kwargs)
    if task_name == 'snli_ve':
        return eval_snli_ve(task, generator, models, sample, **kwargs)
    if task_name == 'image_gen':
        return eval_image_gen(task, generator, models, sample, **kwargs)
    if task_name == 'image_classify':
        return eval_image_classify(task, generator, models, sample, **kwargs)
    raise NotImplementedError("No evaluation function for task '{}'".format(task_name))
def merge_results(task, cfg, logger, score_cnt, score_sum, results):
    """Aggregate evaluation scores across workers, log them, and save predictions.

    With ``distributed_world_size > 1``, ``score_sum`` and ``score_cnt`` are
    summed in place via ``all_reduce``. For every task except ``image_gen``,
    the per-worker ``results`` lists are also gathered, flattened, and written
    by the writing process to ``<results_path>/<gen_subset>_predict.json``.
    For ``image_gen`` only the aggregated score is logged.
    """
    world_size = cfg.distributed_training.distributed_world_size
    saves_predictions = task.cfg._name != 'image_gen'

    gathered = None
    if world_size > 1:
        # Collective calls must happen in the same order on every rank.
        if saves_predictions:
            gathered = [None] * dist.get_world_size()
            dist.all_gather_object(gathered, results)
        dist.all_reduce(score_sum.data)
        dist.all_reduce(score_cnt.data)

    if score_cnt.item() > 0:
        logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
            score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
        ))

    if not saves_predictions:
        return
    # Only the single process, or rank 0, writes the predictions file.
    if world_size != 1 and dist.get_rank() != 0:
        return

    os.makedirs(cfg.common_eval.results_path, exist_ok=True)
    output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
    merged = results if gathered is None else list(chain.from_iterable(gathered))
    with open(output_path, 'w') as fw:
        json.dump(merged, fw)