herrius's picture
Upload 259 files
32b542e
raw
history blame
3.3 kB
# Copyright 2021 JD.com, Inc., JD AI
"""
@author: Yehao Li
@contact: [email protected]
"""
import os
import sys
import pickle
import json
from json import encoder
from .build import EVALUATION_REGISTRY
@EVALUATION_REGISTRY.register()
class VQAEvaler(object):
def __init__(self, cfg, annfile, output_dir):
super(VQAEvaler, self).__init__()
label2ans_path = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "trainval_label2ans.pkl")
ori_annotation = json.load(open(os.path.join(cfg.DATALOADER.ANNO_FOLDER, "v2_mscoco_val2014_annotations.json")))
self.id2type = {t['question_id']: t['question_type'] for t in ori_annotation['annotations']}
self.id2type_answer = {t['question_id']: t['answer_type'] for t in ori_annotation['annotations']}
self.label2ans = pickle.load(open(label2ans_path, "rb"))
self.id2label = {}
if len(annfile) > 0:
answers_val = pickle.load(open(annfile, "rb"))
for datum in answers_val:
quesid = datum['question_id']
self.id2label[quesid] = {}
for i, label in enumerate(datum['labels']):
label_str = self.label2ans[label]
self.id2label[quesid][label_str] = datum['scores'][i]
if output_dir is not None:
self.output_dir = os.path.join(output_dir, 'results')
if not os.path.exists(self.output_dir):
os.mkdir(self.output_dir)
else:
self.output_dir = None
def eval(self, results, epoch):
for res in results:
res['answer'] = self.label2ans[res['answer']]
if self.output_dir is not None:
json.dump(results, open(os.path.join(self.output_dir, str(epoch) + '.json'), "w", encoding="utf-8"))
accuracy = 0.
acc_by_type = dict()
acc_by_type_answer = dict()
for result in results:
quesid = result['question_id']
ans = result['answer']
if quesid not in self.id2type:
print("Test Stage has no target")
return { "accuracy": 0.0 }
type = self.id2type[quesid]
ans_type = self.id2type_answer[quesid]
if type not in acc_by_type:
acc_by_type[type] = [0,0]
if ans_type not in acc_by_type_answer:
acc_by_type_answer[ans_type] = [0,0]
if quesid not in self.id2label:
return { "accuracy": 0.0 }
datum = self.id2label[quesid]
acc_by_type[type][1] += 1
acc_by_type_answer[ans_type][1] += 1
if ans in datum:
accuracy += datum[ans]
acc_by_type[type][0] += datum[ans]
acc_by_type_answer[ans_type][0] += datum[ans]
accuracy = accuracy / len(results)
print('vqa acc: {}'.format(accuracy*100))
for k in acc_by_type.keys():
acc_by_type[k] = acc_by_type[k][0] / acc_by_type[k][1] * 100
for k in acc_by_type_answer.keys():
acc_by_type_answer[k] = acc_by_type_answer[k][0] / acc_by_type_answer[k][1] * 100
print('vqa acc by question type: ', acc_by_type)
print("vqa acc by answer type: ", acc_by_type_answer)
return { "accuracy": accuracy }