"""Evaluator for simple VQA variants with per question-type metrics.

According to the (A-)OKVQA papers, the eval for these datasets should follow
VQAv2. But here we don't do VQAv2-style answer-type tracking, and don't do any
leave-one-out averaging, as this isn't done in the official implementation at
https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py
either. Instead, accuracy is additionally reported per question type (see
QUESTION_TYPES below).
"""
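
# Accuracy convention (see Evaluator.run below): a predicted answer scores
# min(1, number of matching ground-truth answers / 3), following VQAv2. For
# example, if exactly two ground-truth answers match the prediction, the
# example scores min(1, 2/3); with a single ground-truth answer, plain
# exact-match accuracy is used instead.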

import functools

import big_vision.evaluators.common as c
import big_vision.pp.tokenizer
import big_vision.utils as u
import editdistance


API = "jit"

# Question types for which a separate per-type accuracy is reported.
QUESTION_TYPES = ("comp", "count", "presence", "rural_urban", "area")

# Named subsets of question types: for each entry, an averaged accuracy
# "acc_avg_<name>" is reported if all of the subset's types were seen.
ACC_SUBSETS = (
    ("nonum", ("comp", "presence", "rural_urban")),
    ("nonum", ("comp", "presence")),
)
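
# Metrics yielded by Evaluator.run(): "num" (number of evaluated examples) is
# always reported. When ground truth is available, it also yields "acc"
# (VQAv2-style accuracy), "acc_any" (credit if any ground-truth answer matches
# exactly), "anls", one "acc_<type>" per question type seen, "acc_avg" (mean
# over those per-type accuracies), and "acc_avg_<name>" for every ACC_SUBSETS
# entry whose question types are all present.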


class Evaluator:
  """Evaluator for simple VQA tasks."""
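
  # Hypothetical usage sketch (argument values below are illustrative
  # assumptions, not taken from an actual config):
  #
  #   evaluator = Evaluator(
  #       predict_fn,              # decoding predict function
  #       tokenizer="...",         # tokenizer spec for get_tokenizer()
  #       to_lower=True,
  #       data={...},              # eval input pipeline config (incl. "split")
  #       devices=jax.devices())
  #   for name, value in evaluator.run(train_state):
  #     print(name, value)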

  def __init__(
      self, predict_fn, tokenizer, to_lower=False,
      outfile="{workdir}/{split}.json",
      *, data, devices, **kw):
    # Ground-truth and id fields are only needed host-side, so keep them on
    # CPU rather than putting them onto the accelerator devices.
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answers", "answer", "question_id", "question_type"},
        data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    # Optionally lower-case both predictions and ground truths before matching.
    self.postproc = (lambda s: s.lower()) if to_lower else lambda s: s
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)
  def run(self, train_state):
    """Does one evaluation run, yields metrics."""

    accuracies = []
    accuracies_any = []
    counts_per_type = {t: 0 for t in QUESTION_TYPES}
    accuracies_per_type = {t: [] for t in QUESTION_TYPES}
    anls_values = []
    json_out = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      tokens = self.decode(train_state, batch)

      # Keep only this process's local slice of the (sharded) model outputs.
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip padding examples.
          continue

        answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True))

        if "answer" in batch:
          # Single ground-truth answer: plain exact-match accuracy.
          gt = self.postproc(batch["answer"][i])
          gts = [gt]
          accuracies.append(float(answer == gt))
          accuracies_any.append(float(answer == gt))
          anls_values.append(anls_metric(gt, answer))
        elif "answers" in batch and (gt_answers := batch["answers"][i]).size:
          # Multiple ground-truth answers: VQAv2-style soft accuracy.
          gts = [self.postproc(a) for a in gt_answers]
          num_match = sum(answer == gt for gt in gts)
          accuracies.append(min(1.0, num_match / 3.0))
          accuracies_any.append(min(1.0, float(num_match)))
          anls_values.append(max(anls_metric(gt, answer) for gt in gts))
          accuracies_per_type[batch["question_type"][i]].append(
              accuracies_any[-1])
          counts_per_type[batch["question_type"][i]] += 1
        else:
          # No ground truth available: only record the prediction.
          gts = []

        json_out.append({
            "question_id": batch["question_id"][i].item(),
            "answer": answer} | ({"gts": gts} if gts else {}))

    # Sum the per-process partial results across all processes.
    sum_accs, sum_accs_any, sum_anls, num_accs, num = c.process_sum(
        [sum(accuracies), sum(accuracies_any), sum(anls_values),
         len(accuracies), len(json_out)])

    sum_accs_per_type, sum_cnts_per_type = c.process_sum(
        [{k: sum(v) for k, v in accuracies_per_type.items()}, counts_per_type])

    # Only report accuracy metrics if any example came with ground truth.
    if num_accs:
      yield "acc", sum_accs / num_accs
      yield "acc_any", sum_accs_any / num_accs
      yield "anls", sum_anls / num_accs
      acc_types = {}
      for k, v in sum_accs_per_type.items():
        if sum_cnts_per_type[k]:
          acc_types[k] = v / sum_cnts_per_type[k]
          yield f"acc_{k}", acc_types[k]
      if acc_types:  # Avoid div-by-zero if no per-type data was collected.
        yield "acc_avg", sum(acc_types.values()) / len(acc_types)
      for postfix, types in ACC_SUBSETS:
        if all(t in acc_types for t in types):
          yield f"acc_avg_{postfix}", sum(
              v for k, v in acc_types.items() if k in types) / len(types)
    yield "num", num
    c.multiprocess_write_json(self.outfile, json_out)


def anls_metric(target: str, prediction: str, theta: float = 0.5):
  """Calculates ANLS for DocVQA.

  There does not seem to be an official evaluation script.
  Public implementation on which this implementation is based:
  https://github.com/herobd/layoutlmv2/blob/main/eval_docvqa.py#L92

  Original paper (see Eq 1): https://arxiv.org/pdf/1907.00490.pdf

  Args:
    target: Target string.
    prediction: Predicted string.
    theta: Filter threshold set to 0.5 for DocVQA.

  Returns:
    ANLS score.
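
  Example:
    anls_metric("forty two", "fourty two") == 0.9: the edit distance is 1 and
    the normalized distance 1/10 = 0.1 is below theta, so the score is
    1 - 0.1 = 0.9.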
  """
  if target:
    edit_distance = editdistance.eval(target, prediction)
    normalized_ld = edit_distance / max(len(target), len(prediction))
    return 1 - normalized_ld if normalized_ld < theta else 0
  else:
    # An empty target gets full credit only for an empty prediction.
    return float(prediction == "")