|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for ScienceQA. |
|
|
|
Based on the official implementation at
|
https://github.com/lupantech/ScienceQA/blob/main/models/run_gpt3.py |
|
""" |
|
|
|
import functools |
|
import re |
|
|
|
import big_vision.evaluators.common as c |
|
import big_vision.pp.tokenizer |
|
import big_vision.utils as u |
|
|
|
|
|
|
|
|
|
# Evaluator API flavor declared to the big_vision harness (presumably selects
# the jit-based evaluator driver — verify against the evaluator registry).
API = "jit"

# Sentinel answer string returned by `Evaluator.postproc` when no answer
# letter can be parsed from the model output.
FAILURE = "failed"
|
|
|
|
|
class Evaluator:
  """Evaluator for multiple-choice VQA tasks such as ScienceQA.

  Runs the decode function over the eval split, parses the answer letter
  out of each generated string, and reports accuracy plus the fraction of
  unparseable generations. Per-example predictions are also written to a
  json file.
  """

  def __init__(
      self, predict_fn, tokenizer,
      outfile="{workdir}/{split}.json",
      out_question_key="question_id",
      *, data, devices, **kw):
    """Initializes the evaluator.

    Args:
      predict_fn: decode function, called as
        `predict_fn(train_state, batch, devices=..., eos_token=...)`,
        returning generated token arrays.
      tokenizer: spec understood by `big_vision.pp.tokenizer.get_tokenizer`.
      outfile: template for the per-example json output path; `{workdir}`
        and `{split}` are resolved by `c.resolve_outfile`.
      out_question_key: json key under which each example's question id is
        stored in the output file.
      data: input pipeline config; `data["split"]` is used to resolve
        `outfile`.
      devices: jax devices used for decoding.
      **kw: forwarded to `c.eval_input_pipeline`.
    """
    # "answer" and "question_id" are only needed host-side (for metrics and
    # the json output), so keep them off the accelerators.
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={"answer", "question_id"}, data=data, devices=devices, **kw)

    self.outfile = c.resolve_outfile(outfile, split=data.get("split"))
    self.out_question_key = out_question_key

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token
    )

  def postproc(self, raw_answer):
    """Extracts the answer letter (a, b, c, ...) from the raw model output.

    Searches case-insensitively for the phrase "the answer is X." anywhere
    in the string — matching the findall-based official ScienceQA
    implementation — rather than requiring it at the very start.

    Args:
      raw_answer: decoded model output string.

    Returns:
      The single lowercase answer letter, or `FAILURE` if no answer phrase
      is found.
    """
    # `re.search` (not the previous `re.match`) so that a preamble before
    # the answer phrase (e.g. a rationale) does not cause a spurious parse
    # failure; the official implementation uses `findall`, i.e. search
    # semantics.
    match = re.search(r"the answer is ([a-z])\.", raw_answer.lower())
    return match.group(1) if match else FAILURE

  def run(self, train_state):
    """Does one evaluation run, yields metrics.

    Yields:
      ("acc", float): answer accuracy, only when ground truth is present.
      ("parsefail", float): fraction of unparseable generations, only when
        ground truth is present.
      ("num", int): total number of evaluated examples across processes.

    Side effect: writes per-example predictions to `self.outfile`.
    """
    accuracies = []
    fail_parse = []
    json_out = []
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      tokens = self.decode(train_state, batch)

      # Pull this host's local shard of the (possibly fully-sharded) arrays.
      tokens = u.get_local_slice_from_fsarray(tokens)
      ex_masks = u.get_local_slice_from_fsarray(batch["_mask"])

      for i in range(len(tokens)):
        # Skip padding examples added to fill up the final batch.
        if ex_masks[i] == 0:
          continue

        raw_answer = self.tok.to_str(tokens[i], stop_at_eos=True)
        answer = self.postproc(raw_answer)
        if "answer" in batch:
          # The ground truth goes through the same parser so both sides
          # are normalized identically before comparison.
          gt = self.postproc(batch["answer"][i])
          gts = [gt]
          accuracies.append(float(answer == gt))
          fail_parse.append(float(answer == FAILURE))
        else:
          gts = []

        json_out.append(
            {
                self.out_question_key: batch["question_id"][i].item(),
                "raw_answer": raw_answer,
                "answer": answer,
            }
            | ({"gts": gts} if gts else {})
        )

    # Aggregate per-host counts across all processes.
    sum_accs, num_parsefail, num_accs, num = c.process_sum(
        [sum(accuracies), sum(fail_parse), len(accuracies), len(json_out)]
    )

    # Accuracy metrics only exist when ground-truth answers were present.
    if num_accs > 0:
      yield "acc", sum_accs / num_accs
      yield "parsefail", num_parsefail / num_accs

    yield "num", num
    c.multiprocess_write_json(self.outfile, json_out)
|
|