# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluator for VQAV2 dataset.
"""

import functools
import re

import big_vision.evaluators.common as c
import big_vision.pp.tokenizer
import big_vision.utils as u

import numpy as np

# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"


class Evaluator:
  """VQAv2 evaluator."""

  def __init__(
      self, predict_fn, tokenizer, outfile="{workdir}/{split}.json",
      *, data, devices, **kw):
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answers", "answer_type", "question_type", "question_id"},
        data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))

    # We'll need the tokenizer to detokenize the model outputs later.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""
    accuracies_by_type = {"yes/no": [], "number": [], "other": []}
    json_out = []

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      # (batch, seqlen) array of decoded (generated) token sequence suffixes.
      tokens = self.decode(train_state, batch)

      # (local_batch,) mask: padding examples (0) vs real examples (1).
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      # Turn predictions into texts and then scores, one by one.
      for i in range(len(tokens)):
        if ex_masks[i] == 0:  # Skip last-batch padding examples.
          continue

        # Extract the suffix/answer from the generated string, skip bos.
        answer = self.tok.to_str(tokens[i], stop_at_eos=True)
        json = {"question_id": batch["question_id"][i].item(), "answer": answer}

        # The rest is computation of the VQA-score which compares to multiple
        # GTs. This is described better here: https://visualqa.org/evaluation.html
        if (gt_answers := batch["answers"][i]).size:
          # Always need to do light space-processing:
          gt_answers = [stripspace_vqav2(a) for a in gt_answers]
          answer = stripspace_vqav2(answer)

          # Only post-process if not all GTs agree; this supposedly avoids
          # mangling OCR-style answers, see:
          # https://github.com/GT-Vision-Lab/VQA/issues/14#issuecomment-1334695361
          if len(set(gt_answers)) > 1:
            answer = postprocess_vqav2_text(answer)
            gt_answers = [postprocess_vqav2_text(a) for a in gt_answers]

          # Accuracy is the average over all ten leave-one-out subsets of GTs.
          # https://github.com/GT-Vision-Lab/VQA/issues/1#issuecomment-199921352
          # An answer is counted 100% correct as soon as 3 GTs agree with it.
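          # Worked example (hypothetical counts): if 4 of the 10 GT answers
          # match the prediction, every leave-one-out subset still contains at
          # least 3 matches, so the score is 1.0; with only 2 matching GTs the
          # average over the 10 subsets works out to 0.6.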
          matches = answer == np.array(gt_answers)
          acc = np.mean([
              np.clip(np.sum(np.delete(matches, i_leave_out)) / 3, 0, 1)
              for i_leave_out in range(10)
          ])
          accuracies_by_type[batch["answer_type"][i]].append(acc)

          # Update json with the fully post-processed answer and GTs:
          json["answer_raw"] = json["answer"]
          json["answer"] = answer
          json["gts"] = gt_answers

        json_out.append(json)

    # At this point `accuracies_by_type` holds lists of per-example scores
    # (one list per answer type). However, remember that each host holds a
    # different subset of the examples! So if we were to just return the mean
    # accuracy here, we would effectively only have evaluated on the subset of
    # the main host (the one that writes the metrics). So we need to compute
    # global means instead.
    # There is one more caveat: `process_sum` needs the summands on each host
    # to have the same size. So we either need to include dummy values for
    # the padding examples (last batch, annoying), or we only sum scalars as
    # in sufficient statistics, which we do here.
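    # E.g. (hypothetical numbers): if host A holds 3 scores summing to 2.0 and
    # host B holds 5 scores summing to 4.0, the global mean is (2+4)/(3+5)=0.75,
    # whereas averaging the per-host means would give (0.67+0.8)/2 ~= 0.73.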
    sum_accs = c.process_sum({k: sum(v) for k, v in accuracies_by_type.items()})
    num_accs = c.process_sum({k: len(v) for k, v in accuracies_by_type.items()})
    num = c.process_sum(len(json_out))

    # Yielding a (metric_name, value) pair means logging that metric.
    if n := sum(num_accs.values()):
      yield "acc", sum(sum_accs.values()) / n
    if n := num_accs["yes/no"]:
      yield "acc/yesno", sum_accs["yes/no"] / n
      yield "num/yesno", n
    if n := num_accs["number"]:
      yield "acc/number", sum_accs["number"] / n
      yield "num/number", n
    if n := num_accs["other"]:
      yield "acc/other", sum_accs["other"] / n
      yield "num/other", n
    yield "num", num  # Just for sanity checks.

    c.multiprocess_write_json(self.outfile, json_out)


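# For orientation: each record written to `outfile` above looks roughly like
#   {"question_id": 123, "answer": "2", "answer_raw": "two", "gts": [...]}
# (illustrative values); "answer_raw" and "gts" are only present when GT
# answers were available in the data.
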
# Post-processing required is described at https://visualqa.org/evaluation.html
def stripspace_vqav2(txt):
  return txt.replace("\n", " ").replace("\t", " ").strip()


def postprocess_vqav2_text(txt):
  """Cleans up an answer string according to the official VQA evaluation."""
  has_digit_comma = re.search(r"(\d)(\,)(\d)", txt) is not None
  out = txt
  for p in PUNCT:
    # NOTE: digit_comma here looks like a bug in official code, so we follow it.
    if has_digit_comma or f"{p} " in txt or f" {p}" in txt:
      out = out.replace(p, "")
    else:
      out = out.replace(p, " ")

  # Remove full-stops that aren't part of a number.
  # NOTE: the (?!<=\d) lookahead appears to mirror the official code as-is,
  # even though it was presumably meant as a (?<!\d) lookbehind.
  out = re.sub(r"(?!<=\d)(\.)(?!\d)", "", out, flags=re.UNICODE)

  words = []
  for word in out.lower().split():
    if word not in ARTICLES:
      words.append(REPLACEMENTS.get(word, word))
  return " ".join(words)
# pylint: disable=line-too-long
REPLACEMENTS = {
    # CONTRACTIONS
    "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
    "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
    "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
    "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
    "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
    "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
    "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
    "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
    "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
    "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
    "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
    "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
    "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
    "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
    "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
    "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
    "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
    "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
    "youll": "you'll", "youre": "you're", "youve": "you've",
    # NUMBERS
    "none": "0", "zero": "0", "one": "1", "two": "2",
    "three": "3", "four": "4", "five": "5", "six": "6",
    "seven": "7", "eight": "8", "nine": "9", "ten": "10",
}
# pylint: enable=line-too-long
PUNCT = [
";", "/", "[", "]", "\"", "{", "}",
"(", ")", "=", "+", "\\", "_", "-",
">", "<", "@", "`", ",", "?", "!"
]
ARTICLES = {"a", "an", "the"}