"""Evaluator for caption generation metrics used for the MS COCO dataset.""" |
|
import collections |
|
import functools |
|
import os |
|
import tempfile |
|
|
|
import big_vision.evaluators.common as c |
|
import big_vision.input_pipeline |
|
import big_vision.pp.builder |
|
import big_vision.pp.tokenizer |
|
import big_vision.utils as u |
|
|
|
from pycocoevalcap.bleu import bleu |
|
from pycocoevalcap.cider import cider |
|
from pycocoevalcap.meteor import meteor |
|
from pycocoevalcap.rouge import rouge |
|
from pycocoevalcap.spice import spice |
|
from pycocoevalcap.tokenizer import ptbtokenizer |
|
|
|
import jax |
|
|
|
from tensorflow.io import gfile |
|
|
|
|
|
|
|
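# Evaluator API flag checked by the big_vision eval machinery; this evaluator
# uses the jit-based interface.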
API = "jit" |
|
|
|
|
|
class Evaluator: |
|
"""Evaluator for caption generation metrics used for the MS COCO dataset. |
|
|
|
See https://arxiv.org/pdf/1504.00325.pdf or the repository implementing it |
|
https://github.com/tylin/coco-caption for details on the metrics. This code |
|
uses the python3 pip package from: https://github.com/salaniz/pycocoevalcap |
|
|
|
Note that both the model caption and the ground truth reference captions are |
|
further processed with the PTBTokenizer before computing scores. |
|
|
|
`predict_fn` accepts arbitrary dictionaries of parameters and data, where |
|
the data dictionary is produced by the `pp_fn` op. It is expected to output a |
|
dict containing tokenized captions. |
|
|
|
`pp_fn` must have fields: "image/id" and "captions". |
|
""" |
|
|
|
  def __init__(
      self, predict_fn, tokenizer=None,
      metrics=("cider",),
      preds_outfile="{workdir}/{name}_{split}_preds.json",
      annot_outfile="{workdir}/{name}_{split}_annotations.json",
      *, data, devices, **kw
  ):
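    # Build the evaluation input pipeline. "image/id" and "captions" are kept
    # on the host (CPU): they are only needed for scoring, not by the model.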
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"image/id", "captions"}, data=data, devices=devices, **kw)

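    # Resolve the output-file templates (placeholders such as {name} and
    # {split}) into concrete paths.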
    self.preds_outfile = c.resolve_outfile(
        preds_outfile, name=data.get("name"), split=data.get("split"))
    self.annot_outfile = c.resolve_outfile(
        annot_outfile, name=data.get("name"), split=data.get("split"))

    self.metrics = metrics
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
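    # Pre-bind the devices and the tokenizer's EOS token so that decoding only
    # needs the train state and a batch.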
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Run eval."""
    gts = []
    res = []

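    # Run the model on each batch of the evaluation split.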
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      tokens = self.decode(train_state, batch)

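      # Keep only this host's slice of the outputs; `_mask` marks the real
      # examples (the pipeline may pad the final batch).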
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      image_ids = batch["image/id"][ex_masks]
      pred_captions = self.tok.to_str(tokens[ex_masks])

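      # Collect predicted and reference captions in COCO annotation format
      # ({"image_id": ..., "caption": ...}).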
      for image_id, caption in zip(image_ids, pred_captions):
        res.append({"image_id": image_id.item(), "caption": caption})

      for image_id, captions in zip(image_ids, batch["captions"]):
        for caption in captions:
          gts.append({"image_id": image_id.item(), "caption": caption.item()})

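    # Write predictions and annotations to JSON, combining the entries from
    # all processes.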
    res = c.multiprocess_write_json(self.preds_outfile, res)
    gts = c.multiprocess_write_json(self.annot_outfile, gts)

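    # Metric computation happens on the first process only; all others stop here.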
    if jax.process_index():
      return

    outs = self.evaluate(gts, res)
    for key, score in outs.items():
      yield key, score

  def evaluate(self, gt_annotations, res_annotations):
    """Creates the scorers and runs the evaluation."""
    scorers = {
        "rouge": rouge.Rouge,
        "cider": cider.Cider,
        "bleu-4": bleu.Bleu,
        "spice": spice.Spice,
        "meteor": meteor.Meteor,
    }

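    # The pycocoevalcap scorers expect dicts keyed by image id: `res` maps each
    # id to its single predicted caption, `gts` to the list of reference
    # captions. Image ids are remapped to dense integer keys.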
    iid_map = collections.defaultdict(lambda: len(iid_map))
    res = {iid_map[x["image_id"]]: [x] for x in res_annotations}
    gts = collections.defaultdict(list)
    for x in gt_annotations:
      gts[iid_map[x["image_id"]]].append(x)
    assert sorted(gts.keys()) == sorted(res.keys())

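    # PTB-tokenize both references and predictions, mirroring the official COCO
    # caption evaluation.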
    coco_tokenizer = ptbtokenizer.PTBTokenizer()
    gts = coco_tokenizer.tokenize(gts)
    res = coco_tokenizer.tokenize(res)

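    # Compute the corpus-level score for each requested metric; the per-example
    # scores returned by compute_score are discarded. Note that the Bleu scorer
    # returns a list with BLEU-1..4 rather than a single number.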
    scores = {}
    for metric in self.metrics:
      scorer = scorers[metric]()
      scores[metric], _ = scorer.compute_score(gts, res)
    return scores