pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluator for simple VQA variants with per question-type metrics.
According to the (A-)OKVQA papers, the eval for these datasets should follow
VQAv2. But here we don't track different answer-types, and don't do any
leave-one-out averaging, as this isn't done in the official implementation at
https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py
either.
"""
import functools
import big_vision.evaluators.common as c
import big_vision.pp.tokenizer
import big_vision.utils as u
import editdistance
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"

# Question types whose accuracy the evaluator tracks (and reports) separately.
QUESTION_TYPES = (
    "comp",
    "count",
    "presence",
    "rural_urban",
    "area",
)

# (metric postfix, question types) pairs. The evaluator uses them to report
# `acc_avg_<postfix>`: the mean of the per-type accuracies of the listed
# question types, when all of them occurred. "nonum" presumably stands for
# "no numeric questions" (count/area are left out) -- confirm.
ACC_SUBSETS = (
    # rsvqa_lr
    ("nonum", ("comp", "presence", "rural_urban")),
    # rsvqa_hr
    ("nonum", ("comp", "presence")),
)
class Evaluator:
  """Evaluator for simple VQA tasks with per question-type accuracies.

  Decodes the model's answers with `predict_fn`, scores them against the
  ground truth answer(s) of each example and yields metrics aggregated over
  all hosts. Predictions (plus ground truths, when present) are also written
  to `outfile` as JSON.
  """

  def __init__(
      self, predict_fn, tokenizer, to_lower=False,
      outfile="{workdir}/{split}.json",
      *, data, devices, **kw):
    """Initializes the evaluator.

    Args:
      predict_fn: Decoding function, called as
        `predict_fn(train_state, batch, devices=devices, eos_token=eos)`.
      tokenizer: Tokenizer spec resolved with `get_tokenizer`; used to turn
        the generated tokens back into text.
      to_lower: If True, predictions and ground truths are lowercased before
        being compared.
      outfile: Path template for the JSON dump of predictions.
      data: Eval data config; its "split" entry is used to resolve `outfile`.
      devices: Devices passed to the input pipeline and to `predict_fn`.
      **kw: Extra arguments forwarded to `c.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        # Fields that are only read on the host while scoring (see `run`).
        keep_on_cpu={"answers", "answer", "question_id", "question_type"},
        data=data, devices=devices, **kw)
    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))

    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.postproc = (lambda s: s.lower()) if to_lower else lambda s: s
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Does one evaluation run, yields (metric_name, value) pairs.

    Reported metrics:
      acc: exact match for single-GT data; min(1, #matching GTs / 3) for
        multi-GT data.
      acc_any: 1 if the prediction matches any GT (Overall Accuracy, OA).
      anls: ANLS against the best-matching GT.
      acc_<type>: `acc_any` restricted to each question type that occurred.
      acc_avg: mean over the reported `acc_<type>` (Average Accuracy, AA).
      acc_avg_<postfix>: mean over a subset of question types (ACC_SUBSETS).
      num: number of evaluated (non-padding) examples.
    """
    accuracies = []
    accuracies_any = []
    counts_per_type = {t: 0 for t in QUESTION_TYPES}
    accuracies_per_type = {t: [] for t in QUESTION_TYPES}
    anls_values = []
    json_out = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded generated tokens.
      tokens = self.decode(train_state, batch)
      tokens = u.get_local_slice_from_fsarray(tokens)
      # (local_batch,) that indicates padding examples (0) vs real examples (1).
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples
          continue

        answer = self.postproc(self.tok.to_str(tokens[i], stop_at_eos=True))

        # Now we have two commonly used VQA evaluation modes:
        if "answer" in batch:
          # single GT (eg ocrvqa): just compare to that answer, done.
          gt = self.postproc(batch["answer"][i])
          gts = [gt]
          accuracies.append(float(answer == gt))
          accuracies_any.append(float(answer == gt))
          anls_values.append(anls_metric(gt, answer))
        elif "answers" in batch and (gt_answers := batch["answers"][i]).size:
          # multiple GTs (eg okvqa): introduced by VQA, compare to each of them
          # with a threshold, see also: https://visualqa.org/evaluation.html
          gts = [self.postproc(a) for a in gt_answers]
          num_match = sum(answer == gt for gt in gts)
          accuracies.append(min(1.0, num_match / 3.0))
          accuracies_any.append(min(1.0, float(num_match)))
          anls_values.append(max(anls_metric(gt, answer) for gt in gts))
          # Only this branch is broken down by question type.
          qtype = batch["question_type"][i]
          accuracies_per_type[qtype].append(accuracies_any[-1])
          counts_per_type[qtype] += 1
        else:
          gts = []

        json_out.append({
            "question_id": batch["question_id"][i].item(),
            "answer": answer} | ({"gts": gts} if gts else {}))

    # At this point `accuracies` is a list of per-example scores. However,
    # remember that each host holds a different subset of the examples! So if
    # we were to just return the mean accuracy here, we would effectively only
    # have evaluated on the main host's (who writes metrics) subset!
    # So now, we need to compute global means.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as in
    # sufficient statistics, which we do here.
    sum_accs, sum_accs_any, sum_anls, num_accs, num = c.process_sum(
        [sum(accuracies), sum(accuracies_any), sum(anls_values),
         len(accuracies), len(json_out)])
    # Both dicts always have all QUESTION_TYPES as keys, so their structure
    # is identical on every host.
    sum_accs_per_type, sum_cnts_per_type = c.process_sum(
        [{k: sum(v) for k, v in accuracies_per_type.items()}, counts_per_type]
    )

    # Yielding metric_name, value means logging the metric.
    if num_accs:
      yield "acc", sum_accs / num_accs
      yield "acc_any", sum_accs_any / num_accs  # Overall Accuracy (OA).
      yield "anls", sum_anls / num_accs

      # Per question-type accuracy, only for types that occurred at all.
      acc_types = {}
      for k, v in sum_accs_per_type.items():
        if sum_cnts_per_type[k]:
          acc_types[k] = v / sum_cnts_per_type[k]
          yield f"acc_{k}", acc_types[k]

      # Single-GT data never populates per-type counts, so `acc_types` can be
      # empty even though `num_accs` > 0; don't divide by zero then.
      if acc_types:
        yield "acc_avg", sum(acc_types.values()) / len(acc_types)  # Avg (AA).

      # Several ACC_SUBSETS entries may share a postfix (one per dataset
      # variant, broadest subset listed first). Report only the first one that
      # applies, so a narrower subset cannot emit a second, conflicting value
      # under the same metric name.
      reported_postfixes = set()
      for postfix, types in ACC_SUBSETS:
        if postfix in reported_postfixes:
          continue
        if all(t in acc_types for t in types):
          reported_postfixes.add(postfix)
          yield f"acc_avg_{postfix}", sum(
              [v for k, v in acc_types.items() if k in types]
          ) / len(types)  # Average accuracy per question types subset.

    yield "num", num  # Just for sanity checks.
    c.multiprocess_write_json(self.outfile, json_out)
def anls_metric(target: str, prediction: str, theta: float = 0.5) -> float:
  """Calculates ANLS for DocVQA.

  There does not seem to be an official evaluation script.
  Public implementation on which this implementation is based:
  https://github.com/herobd/layoutlmv2/blob/main/eval_docvqa.py#L92

  Original paper (see Eq 1): https://arxiv.org/pdf/1907.00490.pdf

  Args:
    target: Target string.
    prediction: Predicted string.
    theta: Filter threshold set to 0.5 for DocVQA.

  Returns:
    ANLS score as a float: 1 - the normalized Levenshtein distance when that
    distance is below `theta`, else 0.0. For an empty `target` the score is
    1.0 if `prediction` is empty too, else 0.0.
  """
  if not target:
    # With both strings empty the normalization below would be 0/0.
    return float(prediction == "")
  if prediction == target:
    # Exact match: the distance is 0, no need to compute it.
    normalized_ld = 0.0
  else:
    edit_distance = editdistance.eval(target, prediction)
    normalized_ld = edit_distance / max(len(target), len(prediction))
  # Always return a float (the original returned int 0 below the threshold).
  return 1.0 - normalized_ld if normalized_ld < theta else 0.0