|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for TallyQA dataset.""" |
|
|
|
import functools |
|
|
|
import big_vision.evaluators.common as c |
|
import big_vision.pp.tokenizer |
|
import big_vision.utils as u |
|
|
|
|
|
|
|
|
|
# Module-level API marker; presumably tells the big_vision evaluator driver
# how to invoke this evaluator (TODO confirm against the driver).
API: str = "jit"

# Largest ground-truth count that gets its own "count_<n>" metric bucket in
# Evaluator.run, i.e. buckets count_0 .. count_15.
_LARGEST_COUNT: int = 15
|
|
|
|
|
class Evaluator:
  """TallyQA evaluator.

  Decodes model predictions, maps number words to numerals on both the
  prediction and the ground truth, and reports exact-match accuracy overall,
  split by the "issimple" field, and per ground-truth count in
  [0, _LARGEST_COUNT].
  """

  def __init__(self, predict_fn, tokenizer, *, devices, **kw):
    # "answer" and "issimple" are read per-example on the host in `run`,
    # which is why they are listed in keep_on_cpu.
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answer", "issimple"}, devices=devices, **kw)

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token
    )

  @staticmethod
  def _record_example(accuracies_by_type, correct, gt, issimple=None):
    """Appends one example's 0/1 correctness to every bucket it belongs to.

    Args:
      accuracies_by_type: dict mapping bucket name -> list of floats; updated
        in place. Only keys that already exist are written, so the key set
        never changes.
      correct: whether the normalized prediction equals the ground truth.
      gt: normalized ground-truth answer; selects the "count_{gt}" bucket.
      issimple: 1 adds to "simple", 0 adds to "complex"; any other value
        (including None when the field is absent) adds to neither.
    """
    score = float(correct)
    accuracies_by_type["all"].append(score)
    if issimple == 1:
      accuracies_by_type["simple"].append(score)
    elif issimple == 0:
      accuracies_by_type["complex"].append(score)
    # Ground truths outside [0, _LARGEST_COUNT] (or non-numeric ones) have no
    # count bucket; skip them here instead of raising KeyError. They still
    # count towards "all" and the simple/complex splits above.
    count_key = f"count_{gt}"
    if count_key in accuracies_by_type:
      accuracies_by_type[count_key].append(score)

  def run(self, train_state):
    """Does one evaluation run, yields metrics.

    Args:
      train_state: model state, passed through to `predict_fn`.

    Yields:
      (name, value) pairs: "acc" and "num" over all examples, then
      "acc/<bucket>" and "num/<bucket>" for every non-empty bucket among
      "simple", "complex" and "count_0" .. "count_{_LARGEST_COUNT}".
    """
    accuracies_by_type = {"all": [], "simple": [], "complex": []}
    # Count buckets are pre-allocated rather than created on demand so the
    # dict has the same keys everywhere it is passed to c.process_sum
    # (presumably a cross-process reduction).
    accuracies_by_type.update(
        {f"count_{i}": [] for i in range(_LARGEST_COUNT + 1)})

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      tokens = u.get_local_slice_from_fsarray(self.decode(train_state, batch))
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])
      has_issimple = "issimple" in batch

      for i, (example_tokens, mask) in enumerate(zip(tokens, ex_masks)):
        if mask == 0:  # Masked-out example (presumably batch padding).
          continue
        # Predictions and labels may spell numbers as words; compare numerals.
        answer = _number_word_to_numeral(
            self.tok.to_str(example_tokens, stop_at_eos=True))
        gt = _number_word_to_numeral(batch["answer"][i])
        issimple = batch["issimple"][i] if has_issimple else None
        self._record_example(accuracies_by_type, answer == gt, gt, issimple)

    sum_accs = c.process_sum({k: sum(v) for k, v in accuracies_by_type.items()})
    num_accs = c.process_sum({k: len(v) for k, v in accuracies_by_type.items()})

    if n := num_accs["all"]:
      yield "acc", sum_accs["all"] / n
      yield "num", n
    # Every recorded example lands in "all", so when n == 0 every other
    # bucket is empty too and nothing below is yielded.
    for key, total in sum_accs.items():
      if key != "all" and (count := num_accs[key]):
        yield f"acc/{key}", total / count
        yield f"num/{key}", count
|
|
|
|
|
def _number_word_to_numeral(s: str) -> str:
  """Maps an English number word to its numeral string, e.g. "Three" -> "3".

  The lookup lowercases `s` and consults REPLACEMENTS ("none"/"zero" through
  "twenty"). Any string not found there, including one that is already a
  numeral, is returned unchanged.
  """
  numeral = REPLACEMENTS.get(s.lower())
  if numeral is None:
    return s
  return numeral
|
|
|
|
|
# Lowercase English number words -> numeral strings. "none" is an extra
# spelling of zero; the remaining words run "zero" .. "twenty", and each
# word's position in the tuple below is its numeric value.
REPLACEMENTS = {
    "none": "0",
    **{
        word: str(value)
        for value, word in enumerate((
            "zero", "one", "two", "three", "four", "five", "six", "seven",
            "eight", "nine", "ten", "eleven", "twelve", "thirteen",
            "fourteen", "fifteen", "sixteen", "seventeen", "eighteen",
            "nineteen", "twenty",
        ))
    },
}
|
|