"""Evaluator to run inference and store results.""" |
|
import functools |
|
|
|
import big_vision.evaluators.common as c |
|
import big_vision.input_pipeline |
|
import big_vision.pp.builder |
|
import big_vision.pp.tokenizer |
|
import big_vision.utils as u |
|
|
|
import jax |
|
|
|
|
|
|
|
API = "jit" |
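

# Illustrative sketch (ours, not part of big_vision's API): the Evaluator
# below only assumes that `predict_fn(train_state, batch, devices=...,
# eos_token=...)` returns one row of decoded token ids per example in the
# batch. A trivial stand-in honoring that shape contract could look like
# this; a real predict_fn would run the model's decoding loop and return a
# sharded jax.Array rather than a plain numpy array.
def _sketch_predict_fn(train_state, batch, *, devices, eos_token):
  """Stand-in decoder: predicts an immediate EOS for every example."""
  import numpy as np  # Local import: used only by this sketch.
  del train_state, devices  # A real decoder would use these.
  num_examples = len(batch["_mask"])  # One prediction per (padded) example.
  return np.full((num_examples, 1), eos_token, dtype=np.int32)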


class Evaluator:
  """Evaluator to run inference and store results."""

  def __init__(
      self, predict_fn, tokenizer=None,
      preds_outfile="{workdir}/{name}_{split}_preds.json",
      annot_outfile="{workdir}/{name}_{split}_annotations.json",
      id_key="id",
      *, data, devices, **kw
  ):
    self.id_key = id_key
    # Keep the example ids on the host: they are only needed to key the
    # output JSON records, not by the model.
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={id_key}, data=data, devices=devices, **kw)

    # Resolve the outfile templates; name and split come from the data config.
    self.preds_outfile = c.resolve_outfile(
        preds_outfile, name=data.get("name"), split=data.get("split", ""))
    self.annot_outfile = c.resolve_outfile(
        annot_outfile, name=data.get("name"), split=data.get("split", ""))

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    # Bind the static decoding arguments once; `run` then only passes the
    # current train state and batch.
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Run eval."""
    res = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # Decode the batch into token ids.
      tokens = self.decode(train_state, batch)

      # Keep only the slices of the sharded arrays that live on this
      # process's local devices.
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Drop padding examples (mask is False for them) and detokenize.
      image_ids = batch[self.id_key][ex_masks]
      pred_captions = self.tok.to_str(tokens[ex_masks])

      for image_id, caption in zip(image_ids, pred_captions):
        res.append({self.id_key: str(image_id), "caption": caption})

    # Combine the per-process predictions and write them as one JSON file.
    res = c.multiprocess_write_json(self.preds_outfile, res)

    # Only the first process reports metrics; the others are done here.
    if jax.process_index():
      return

    yield "num_examples", len(res)