# coding=utf-8
# Copyleft 2019 project LXRT.
import json
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
from ..param import args
from ..utils import load_obj_tsv
# Load part of the dataset for fast checking.
# Note that these counts are numbers of images, not numbers of examples:
# every example associated with a kept image is used.
TINY_IMG_NUM = 512
FAST_IMG_NUM = 5000
# Paths to the data and image features.
VQA_DATA_ROOT = 'data/vqa/'
MSCOCO_IMGFEAT_ROOT = 'data/mscoco_imgfeat/'
SPLIT2NAME = {
'train': 'train2014',
'valid': 'val2014',
'minival': 'val2014',
'nominival': 'val2014',
'test': 'test2015',
}
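# For example, loading the 'minival' split reads image features from
# data/mscoco_imgfeat/val2014_obj36.tsv (see VQATorchDataset below).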
class VQADataset:
"""
    A VQA data example in the json file:
{
"answer_type": "other",
"img_id": "COCO_train2014_000000458752",
"label": {
"net": 1
},
"question_id": 458752000,
"question_type": "what is this",
"sent": "What is this photo taken looking through?"
}
"""
def __init__(self, splits: str):
self.name = splits
self.splits = splits.split(',')
# Loading datasets
self.data = []
for split in self.splits:
self.data.extend(json.load(open("data/vqa/%s.json" % split)))
print("Load %d data from split(s) %s." % (len(self.data), self.name))
# Convert list to dict (for evaluation)
self.id2datum = {
datum['question_id']: datum
for datum in self.data
}
# Answers
self.ans2label = json.load(open("data/vqa/trainval_ans2label.json"))
self.label2ans = json.load(open("data/vqa/trainval_label2ans.json"))
assert len(self.ans2label) == len(self.label2ans)
@property
def num_answers(self):
return len(self.ans2label)
def __len__(self):
return len(self.data)
"""
An example in obj36 tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
"""
class VQATorchDataset(Dataset):
def __init__(self, dataset: VQADataset):
super().__init__()
self.raw_dataset = dataset
if args.tiny:
topk = TINY_IMG_NUM
elif args.fast:
topk = FAST_IMG_NUM
else:
topk = None
        # Load detection features into img_data
img_data = []
for split in dataset.splits:
            # minival is a 5K-image subset of MS COCO val2014, used for
            # evaluating VQA / LXMERT pre-training.
            # It is saved as the top 5K features in val2014_***.tsv.
load_topk = 5000 if (split == 'minival' and topk is None) else topk
img_data.extend(load_obj_tsv(
os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj36.tsv' % (SPLIT2NAME[split])),
topk=load_topk))
# Convert img list to dict
self.imgid2img = {}
for img_datum in img_data:
self.imgid2img[img_datum['img_id']] = img_datum
        # Only keep examples whose image features were loaded
self.data = []
for datum in self.raw_dataset.data:
if datum['img_id'] in self.imgid2img:
self.data.append(datum)
print("Use %d data in torch dataset" % (len(self.data)))
print()
def __len__(self):
return len(self.data)
def __getitem__(self, item: int):
datum = self.data[item]
img_id = datum['img_id']
ques_id = datum['question_id']
ques = datum['sent']
# Get image info
img_info = self.imgid2img[img_id]
obj_num = img_info['num_boxes']
feats = img_info['features'].copy()
boxes = img_info['boxes'].copy()
assert obj_num == len(boxes) == len(feats)
# Normalize the boxes (to 0 ~ 1)
img_h, img_w = img_info['img_h'], img_info['img_w']
boxes[:, (0, 2)] /= img_w
boxes[:, (1, 3)] /= img_h
np.testing.assert_array_less(boxes, 1+1e-5)
np.testing.assert_array_less(-boxes, 0+1e-5)
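        # The 1e-5 slack in the asserts tolerates boxes that touch or barely
        # exceed the image border after normalization.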
# Provide label (target)
if 'label' in datum:
label = datum['label']
target = torch.zeros(self.raw_dataset.num_answers)
for ans, score in label.items():
target[self.raw_dataset.ans2label[ans]] = score
return ques_id, feats, boxes, ques, target
else:
return ques_id, feats, boxes, ques
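

# With a default torch DataLoader, the numpy feats/boxes returned above are
# collated into (batch, num_boxes, feat_dim) and (batch, num_boxes, 4) float
# tensors, ques_id into a tensor, and ques into a list of strings.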
class VQAEvaluator:
def __init__(self, dataset: VQADataset):
self.dataset = dataset
def evaluate(self, quesid2ans: dict):
score = 0.
for quesid, ans in quesid2ans.items():
datum = self.dataset.id2datum[quesid]
label = datum['label']
if ans in label:
score += label[ans]
return score / len(quesid2ans)
def dump_result(self, quesid2ans: dict, path):
"""
        Dump results to a json file that can be submitted to the VQA online evaluation server.
VQA json file submission requirement:
results = [result]
result = {
"question_id": int,
"answer": str
}
:param quesid2ans: dict of quesid --> ans
        :param path: The desired path of the saved file.
"""
with open(path, 'w') as f:
result = []
for ques_id, ans in quesid2ans.items():
result.append({
'question_id': ques_id,
'answer': ans
})
json.dump(result, f, indent=4, sort_keys=True)
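

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline. It assumes the
    # data/ json files and feature tsv files described above are present; the
    # split and output path below are illustrative. Because of the relative
    # imports at the top, run it as `python -m <package>.vqa_data` rather
    # than as a bare script.
    from torch.utils.data import DataLoader

    dset = VQADataset('minival')
    tset = VQATorchDataset(dset)
    evaluator = VQAEvaluator(dset)

    loader = DataLoader(tset, batch_size=32, shuffle=False)
    quesid2ans = {}
    for ques_id, feats, boxes, ques, target in loader:
        # A real model would consume (feats, boxes, ques); as a placeholder we
        # "predict" the highest-scoring gold answer for each question.
        label_ids = target.argmax(dim=1)
        for qid, lid in zip(ques_id.tolist(), label_ids.tolist()):
            quesid2ans[qid] = dset.label2ans[lid]
    # This is the oracle upper bound on accuracy under the soft labels.
    print("Sanity-check accuracy: %.4f" % evaluator.evaluate(quesid2ans))
    evaluator.dump_result(quesid2ans, '/tmp/minival_predict.json')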