|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for the VQAv2 dataset."""
|
import functools |
|
import re |
|
|
|
import big_vision.evaluators.common as c |
|
import big_vision.pp.tokenizer |
|
import big_vision.utils as u |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
# Evaluator API flavor declared for the evaluation framework. Presumably this
# tells the evaluator loader that `predict_fn` uses the "jit" calling
# convention; in this file it is invoked as
# `predict_fn(train_state, batch, devices=..., eos_token=...)`. TODO confirm
# against the big_vision evaluator loader.
API = "jit"
|
|
|
|
|
class Evaluator:
  """VQAv2 evaluator.

  For every example of the evaluation split, decodes an answer with
  `predict_fn`, scores it against the example's ground-truth answers (when
  any are present) using leave-one-out VQA accuracy, and writes one JSON
  record per example to `outfile`.
  """

  def __init__(
      self, predict_fn, tokenizer, outfile="{workdir}/{split}.json",
      *, data, devices, **kw):
    """Sets up the input pipeline, output path and decoding function.

    Args:
      predict_fn: decoding function, later called as
        `predict_fn(train_state, batch, devices=devices, eos_token=...)`.
      tokenizer: tokenizer spec for `big_vision.pp.tokenizer.get_tokenizer`;
        supplies the EOS token and converts decoded tokens to strings.
      outfile: output path template, resolved by `c.resolve_outfile` with the
        split name taken from `data.get("split")`.
      data: dataset config, forwarded to `c.eval_input_pipeline`.
      devices: forwarded to both the input pipeline and `predict_fn`.
      **kw: further options for `c.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        # Metadata fields requested on the host; `run` reads all of them
        # except "question_type" per example.
        keep_on_cpu={"answers", "answer_type", "question_type", "question_id"},
        data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

  @staticmethod
  def _vqa_accuracy(answer, gt_answers):
    """Returns the leave-one-out VQA accuracy of `answer`.

    Each ground-truth answer is left out in turn; that round scores
    min(1, number of remaining answers equal to `answer` / 3). The result is
    the mean over all rounds. With 10 ground-truth answers this is the usual
    VQA accuracy. Any other count is handled the same way, instead of being
    hard-wired to 10 rounds (which raised IndexError for fewer answers and
    ignored answers beyond the tenth).

    Args:
      answer: predicted answer string.
      gt_answers: non-empty sequence of ground-truth answer strings.

    Returns:
      The accuracy as a numpy float in [0, 1].
    """
    matches = answer == np.array(gt_answers)
    return np.mean([
        np.clip(np.sum(np.delete(matches, i_leave_out)) / 3, 0, 1)
        for i_leave_out in range(len(matches))
    ])

  def run(self, train_state):
    """Does one evaluation run, yields metrics.

    Args:
      train_state: model state, passed through to the decoding function.

    Yields:
      `(name, value)` pairs, in this order:
        - "acc": overall accuracy, if any example was scored.
        - "acc/yesno", "num/yesno", "acc/number", "num/number",
          "acc/other", "num/other": accuracy and scored-example count per
          answer type, each pair only if that type has scored examples.
        - "num": number of written records.
      After the last yield, all records are written to `self.outfile`.
    """
    accuracies_by_type = {"yes/no": [], "number": [], "other": []}
    records = []

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      tokens = self.decode(train_state, batch)

      # Keep only this process's slice of the outputs and mask; the counts
      # and sums are combined across processes with `c.process_sum` below.
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      for i, example_tokens in enumerate(tokens):
        if ex_masks[i] == 0:
          continue  # Masked-out example (presumably batch padding).

        answer = self.tok.to_str(example_tokens, stop_at_eos=True)
        record = {"question_id": batch["question_id"][i].item(),
                  "answer": answer}

        # Examples without ground-truth answers are recorded but not scored.
        if (gt_answers := batch["answers"][i]).size:
          gt_answers = [stripspace_vqav2(a) for a in gt_answers]
          answer = stripspace_vqav2(answer)

          # Full text normalization is applied only when the annotators
          # disagree (presumably mirroring the official VQA evaluation
          # script -- confirm before changing).
          if len(set(gt_answers)) > 1:
            answer = postprocess_vqav2_text(answer)
            gt_answers = [postprocess_vqav2_text(a) for a in gt_answers]

          acc = self._vqa_accuracy(answer, gt_answers)
          accuracies_by_type[batch["answer_type"][i]].append(acc)

          # Keep the raw decoded string alongside the normalized one.
          record["answer_raw"] = record["answer"]
          record["answer"] = answer
          record["gts"] = gt_answers

        records.append(record)

    sum_accs = c.process_sum({k: sum(v) for k, v in accuracies_by_type.items()})
    num_accs = c.process_sum({k: len(v) for k, v in accuracies_by_type.items()})
    num = c.process_sum(len(records))

    if n := sum(num_accs.values()):
      yield "acc", sum(sum_accs.values()) / n
    for answer_type, suffix in (("yes/no", "yesno"), ("number", "number"),
                                ("other", "other")):
      if n := num_accs[answer_type]:
        yield f"acc/{suffix}", sum_accs[answer_type] / n
        yield f"num/{suffix}", n

    yield "num", num
    c.multiprocess_write_json(self.outfile, records)
|
|
|
|
|
|
|
|
|
|
|
def stripspace_vqav2(txt):
  """Maps newlines and tabs to spaces, then trims leading/trailing whitespace.

  Args:
    txt: answer string.

  Returns:
    The cleaned string; interior spacing is otherwise left untouched.
  """
  newline_and_tab_to_space = {ord("\n"): " ", ord("\t"): " "}
  return txt.translate(newline_and_tab_to_space).strip()
|
|
|
|
|
def postprocess_vqav2_text(txt):
  """Normalizes an answer string before VQA answer matching.

  Steps:
    1. Each mark in `PUNCT` is deleted if the text contains a
       digit-comma-digit run (e.g. "1,000") or if the mark touches a space
       somewhere in the input. Otherwise the mark is replaced by a space.
    2. Periods not followed by a digit are deleted.
    3. The text is lowercased and split on whitespace. Words in `ARTICLES`
       are dropped, and each remaining word is mapped through `REPLACEMENTS`.

  Args:
    txt: answer string.

  Returns:
    The normalized words joined by single spaces.
  """
  # A digit-comma-digit run anywhere means every mark is simply deleted.
  delete_all_marks = re.search(r"(\d)(\,)(\d)", txt) is not None

  cleaned = txt
  for mark in PUNCT:
    # Space adjacency is tested on the untouched input `txt`, not on the
    # partially cleaned string.
    touches_space = (mark + " ") in txt or (" " + mark) in txt
    substitute = "" if (delete_all_marks or touches_space) else " "
    cleaned = cleaned.replace(mark, substitute)

  # NOTE(review): `(?!<=\d)` is a negative *lookahead* for the literal text
  # "<=" followed by a digit (a lookbehind would be `(?<!\d)`). It can never
  # block a match here, so every period not followed by a digit is removed,
  # including one preceded by a digit. This presumably preserves
  # bug-compatibility with the reference VQA scorer; do not "fix" it without
  # confirming, since scores would change.
  cleaned = re.sub(r"(?!<=\d)(\.)(?!\d)", "", cleaned, flags=re.UNICODE)

  kept_words = [w for w in cleaned.lower().split() if w not in ARTICLES]
  return " ".join(REPLACEMENTS.get(w, w) for w in kept_words)
|
|
|
|
|
|
|
# Word-level substitutions used by `postprocess_vqav2_text`, which looks each
# word up here after lowercasing the text and dropping articles. The first
# group maps apostrophe-less contractions to their apostrophe form; the last
# group maps the number words "none"/"zero".."ten" to digit strings.
#
# NOTE(review): keys containing uppercase letters ("Id've", "I'dve", "Im",
# "Ive") can never match, because words are lowercased before the lookup.
# "let's" and "she's" map to themselves (no-ops), and "somebody'd" ->
# "somebodyd" looks inverted relative to its neighbours. These are presumably
# copied verbatim from the reference VQA evaluation's contraction table; any
# edit changes reported accuracies, so confirm before "fixing".
REPLACEMENTS = {
    # Contractions.
    "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
    "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
    "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
    "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
    "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
    "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
    "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
    "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
    "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
    "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
    "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
    "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
    "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
    "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
    "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
    "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
    "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
    "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
    "youll": "you'll", "youre": "you're", "youve": "you've",
    # Number words to digits.
    "none": "0", "zero": "0", "one": "1", "two": "2",
    "three": "3", "four": "4", "five": "5", "six": "6",
    "seven": "7", "eight": "8", "nine": "9", "ten": "10",
}
|
|
|
|
|
# Single-character punctuation marks handled by `postprocess_vqav2_text`, in
# this order: ; / [ ] " { } ( ) = + \ _ - > < @ ` , ? !
# (Periods are handled separately by a regex.)
PUNCT = list(';/[]"{}()=+\\_-><@`,?!')

# Words dropped during answer normalization.
ARTICLES = set("a an the".split())
|
|