"""Evaluator for ChartQA variants.""" |
|
|
|
import functools |
|
|
|
import big_vision.evaluators.common as c |
|
import big_vision.pp.tokenizer |
|
import big_vision.utils as u |
|
|
|
|
|
|
|
|
|
API = "jit" |
|
|
|
|
|
class Evaluator:
  """Evaluator for simple VQA tasks."""

  def __init__(
      self, predict_fn, tokenizer, to_lower=False,
      outfile="{workdir}/{split}.json",
      out_question_key="question_id", out_answer_key="answer",
      *, data, devices, **kw):
    # Ground-truth answers and question ids stay on the host; only the model
    # inputs are sharded across the devices.
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answer", "question_id"}, data=data, devices=devices, **kw)

    # Resolve the output file template (e.g. its {split} placeholder).
    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
    self.out_question_key = out_question_key
    self.out_answer_key = out_answer_key

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    # Optionally lower-case both predictions and ground truth before matching.
    self.postproc = (lambda s: s.lower()) if to_lower else lambda s: s
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""

    accuracies = []
    relaxed_accuracies = []
    json_out = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # Decode the predicted answer tokens for the (sharded) batch.
      tokens = self.decode(train_state, batch)

      # Keep only the slice of the predictions that lives on this host.
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      for i in range(len(tokens)):
        # Skip padding examples that were added to fill up the last batch.
        if ex_masks[i] == 0:
          continue

        answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True))
        gt = self.postproc(batch["answer"][i])
        accuracies.append(float(answer == gt))
        relaxed_accuracies.append(_relaxed_match(gt, answer))
        json_out.append({
            self.out_question_key: batch["question_id"][i].item(),
            self.out_answer_key: answer,
            "gt": gt,
            "relaxed_match": relaxed_accuracies[-1],
        })

    # Sum the per-example scores across all processes before averaging.
    sum_accs, sum_relaxed_accs, num = c.process_sum(
        [sum(accuracies), sum(relaxed_accuracies), len(accuracies)])

    yield "acc", sum_accs / num
    yield "relaxed_acc", sum_relaxed_accs / num
    yield "num", num
    c.multiprocess_write_json(self.outfile, json_out)
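

# A minimal usage sketch (illustrative only: `predict_fn`, `train_state`,
# the tokenizer spec, and the data spec below are placeholders, not a fixed
# big_vision API):
#
#   evaluator = Evaluator(
#       predict_fn,
#       tokenizer="...",                            # tokenizer spec string
#       data={"name": "chartqa", "split": "val"},   # hypothetical data spec
#       devices=jax.devices(),
#   )
#   for name, value in evaluator.run(train_state):
#     print(f"{name}: {value}")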


def _to_float(text: str) -> float | None:
  """Parses `text` as a float; "42%" parses to 0.42. Returns None on failure."""
  try:
    if text.endswith("%"):
      # Percentages are converted to their fractional value.
      return float(text.rstrip("%")) / 100.0
    else:
      return float(text)
  except ValueError:
    return None
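
# Illustrative parses (values chosen for this sketch):
#   _to_float("42%")  -> 0.42
#   _to_float("3.5")  -> 3.5
#   _to_float("n/a")  -> None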


def _relaxed_match(target: str,
                   prediction: str,
                   max_relative_error: float = 0.05) -> bool:
  """Calculates relaxed correctness.

  The correctness tolerates a certain error ratio, defined by
  max_relative_error. See https://arxiv.org/pdf/2203.10244.pdf, end of
  section 5.1: “Following Methani et al. (2020), we use a relaxed accuracy
  measure for the numeric answers to allow a minor inaccuracy that may result
  from the automatic data extraction process. We consider an answer to be
  correct if it is within 5% of the gold answer. For non-numeric answers, we
  still need an exact match to consider an answer to be correct.”

  Args:
    target: Target string.
    prediction: Predicted string.
    max_relative_error: Maximum relative error.

  Returns:
    Whether the prediction was correct given the specified tolerance.
  """
  prediction_float = _to_float(prediction)
  target_float = _to_float(target)

  # Apply the numeric tolerance only when both sides parse as numbers and the
  # target is non-zero (a zero target would make the relative error undefined);
  # otherwise fall back to an exact string match.
  if prediction_float is not None and target_float:
    relative_error = abs(prediction_float - target_float) / abs(target_float)
    return relative_error <= max_relative_error
  else:
    return prediction == target
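
# Example behaviour of the relaxed metric (illustrative values):
#   _relaxed_match("10", "10.4")  -> True   (4% relative error, within 5%)
#   _relaxed_match("10", "10.6")  -> False  (6% relative error)
#   _relaxed_match("10%", "0.1")  -> True   (both sides parse to 0.1)
#   _relaxed_match("cat", "Cat")  -> False  (non-numeric: exact match only)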