File size: 2,606 Bytes
74e8f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluator to run inference and store results."""
import functools
import big_vision.evaluators.common as c
import big_vision.input_pipeline
import big_vision.pp.builder
import big_vision.pp.tokenizer
import big_vision.utils as u
import jax
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
# NOTE(review): the removal date above predates this file's 2024 copyright
# header -- confirm whether anything still reads this flag before deleting it.
API = "jit"
class Evaluator:
  """Runs a prediction function over an eval dataset and stores the outputs.

  For every batch the predicted tokens are converted to strings with the
  tokenizer and paired with the example ID found under `id_key`; the collected
  records are then handed to `c.multiprocess_write_json`.
  """

  def __init__(
      self, predict_fn, tokenizer=None,
      preds_outfile="{workdir}/{name}_{split}_preds.json",
      annot_outfile="{workdir}/{name}_{split}_annotations.json",
      id_key="id",
      *, data, devices, **kw
  ):
    """Sets up the input pipeline, output paths, tokenizer and decode function.

    Args:
      predict_fn: Callable invoked by `run` as
        `predict_fn(train_state, batch, devices=..., eos_token=...)`.
      tokenizer: Forwarded as-is to `big_vision.pp.tokenizer.get_tokenizer`.
      preds_outfile: Path template for predictions; resolved through
        `c.resolve_outfile` using the `name` and `split` entries of `data`.
      annot_outfile: Path template for annotations, resolved the same way. It
        is stored on the instance but not read by `run` in this class.
      id_key: Batch key holding the per-example ID. It is requested to be kept
        on CPU and is also used as the ID key in every output record.
      data: Mapping describing the dataset; `name` and `split` are read from
        it and the whole mapping is forwarded to `c.eval_input_pipeline`.
      devices: Forwarded to both `c.eval_input_pipeline` and `predict_fn`.
      **kw: Extra keyword arguments forwarded to `c.eval_input_pipeline`.
    """
    self.id_key = id_key

    pipeline = c.eval_input_pipeline(
        keep_on_cpu={id_key}, data=data, devices=devices, **kw)
    self.get_data_iter, self.steps = pipeline

    # Both output templates are filled from the same dataset fields.
    outfile_fields = dict(name=data.get("name"), split=data.get("split", ""))
    self.preds_outfile = c.resolve_outfile(preds_outfile, **outfile_fields)
    self.annot_outfile = c.resolve_outfile(annot_outfile, **outfile_fields)

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    eos_token = self.tok.eos_token
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=eos_token)

  def run(self, train_state):
    """Decodes up to `self.steps` batches, writes predictions, reports count.

    Args:
      train_state: Passed unchanged as first argument to the decode function.

    Yields:
      A single `("num_examples", count)` pair, and only on the process whose
      `jax.process_index()` is 0; `count` is the length of the value returned
      by `c.multiprocess_write_json`. Other processes yield nothing.
    """
    records = []

    # Pairing with `range(self.steps)` caps the number of batches consumed.
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens.
      decoded = self.decode(train_state, batch)

      # Keep only the slice belonging to this process: (local_batch,)
      local_tokens = u.get_local_slice_from_fsarray(decoded)
      valid = u.get_local_slice_from_fsarray(batch["_mask"])

      # Index with the mask so that masked-out examples are not written.
      example_ids = batch[self.id_key][valid]
      captions = self.tok.to_str(local_tokens[valid])

      records.extend(
          {self.id_key: str(example_id), "caption": text}
          for example_id, text in zip(example_ids, captions)
      )

    written = c.multiprocess_write_json(self.preds_outfile, records)

    # Host0 gets all preds and does eval; other hosts report nothing.
    if jax.process_index() == 0:
      yield "num_examples", len(written)
|