Spaces:
Sleeping
Sleeping
File size: 5,985 Bytes
08d7644 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
# coding=utf-8
# Copyleft 2019 project LXRT.
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from param import args
from utils import load_obj_tsv
# Load only part of the dataset for fast checking.
# Note that these limits count images, not data examples: every example
# that refers to one of the loaded images is used.
TINY_IMG_NUM = 512
FAST_IMG_NUM = 5000
class GQADataset:
    """
    In-memory question/answer examples of one or more GQA splits.

    A GQA data example in json file:
        {
            "img_id": "2375429",
            "label": {
                "pipe": 1.0
            },
            "question_id": "07333408",
            "sent": "What is on the white wall?"
        }

    Attributes set by the constructor:
        name: the comma-separated split string exactly as given.
        splits: the list of split names obtained by splitting ``name`` on ','.
        data: the examples of all splits, concatenated in split order.
        id2datum: maps each example's ``question_id`` to the example.
        ans2label: answer -> label mapping read from
            ``data/gqa/trainval_ans2label.json``.
        label2ans: inverse mapping (indexable by label) read from
            ``data/gqa/trainval_label2ans.json``.
    """

    def __init__(self, splits: str):
        """
        Load every requested split and the answer vocabulary.

        :param splits: Comma-separated split names; each one is read from
            ``data/gqa/<split>.json`` (relative to the working directory).
        :raises AssertionError: if the two answer-vocabulary files are not
            inverses of each other.
        """
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets to data
        self.data = []
        for split in self.splits:
            self.data.extend(self._load_json("data/gqa/%s.json" % split))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # List to dict (for evaluation and others)
        self.id2datum = {
            datum['question_id']: datum
            for datum in self.data
        }

        # Answers
        self.ans2label = self._load_json("data/gqa/trainval_ans2label.json")
        self.label2ans = self._load_json("data/gqa/trainval_label2ans.json")

        # Sanity check: the two vocabulary files must describe the same mapping.
        assert len(self.ans2label) == len(self.label2ans)
        for ans, label in self.ans2label.items():
            assert self.label2ans[label] == ans

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path``, closing the file handle afterwards."""
        with open(path) as f:
            return json.load(f)

    @property
    def num_answers(self):
        """Number of entries in the answer vocabulary."""
        return len(self.ans2label)

    def __len__(self):
        """Number of loaded examples across all splits."""
        return len(self.data)
class GQABufferLoader:
    """Cache of loaded object-feature TSV files, keyed by file path and image count."""

    def __init__(self):
        # "<tsv path>_<number>" -> the value load_obj_tsv returned for that pair
        self.key2data = {}

    def load_data(self, name, number):
        """
        Return the image features for ``name``, loading them at most once.

        :param name: 'testdev' selects the testdev TSV; any other value selects
            the shared Visual Genome TSV.
        :param number: Forwarded to ``load_obj_tsv`` as ``topk``; it is also
            part of the cache key, so different values are cached separately.
        :return: The cached result of ``load_obj_tsv`` for this path/number.
        """
        tsv_path = (
            "data/vg_gqa_imgfeat/gqa_testdev_obj36.tsv"
            if name == 'testdev'
            else "data/vg_gqa_imgfeat/vg_gqa_obj36.tsv"
        )
        cache_key = "%s_%d" % (tsv_path, number)

        try:
            return self.key2data[cache_key]
        except KeyError:
            pass

        # First request for this (path, number) pair: read the TSV and remember it.
        loaded = load_obj_tsv(tsv_path, topk=number)
        self.key2data[cache_key] = loaded
        return loaded


# Module-level loader shared by every dataset instance created in this module.
gqa_buffer_loader = GQABufferLoader()
"""
Example in obj tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
"""
class GQATorchDataset(Dataset):
    """
    Torch ``Dataset`` over a ``GQADataset`` that pairs every question with the
    detection features of its image.

    Only examples whose image features were actually loaded are kept, so the
    length of this dataset may be smaller than that of the raw dataset.
    """

    def __init__(self, dataset: "GQADataset"):
        """
        Load the image features needed by ``dataset`` and filter its examples.

        :param dataset: The raw question/answer dataset to wrap.
        """
        super().__init__()
        self.raw_dataset = dataset

        # Number of images to load; -1 is what the full (non-debug) run passes on.
        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM
        else:
            topk = -1

        # Loading detection features to img_data
        # Since images in train and valid both come from Visual Genome,
        # buffer the image loading to save memory.
        if 'testdev' in dataset.splits or 'testdev_all' in dataset.splits:
            # Always loading all the data in testdev
            img_data = gqa_buffer_loader.load_data('testdev', -1)
        else:
            img_data = gqa_buffer_loader.load_data('train', topk)

        self.imgid2img = {
            img_datum['img_id']: img_datum
            for img_datum in img_data
        }

        # Only keep the data with loaded image features
        self.data = [
            datum
            for datum in self.raw_dataset.data
            if datum['img_id'] in self.imgid2img
        ]
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

    def __len__(self):
        """Number of examples that have loaded image features."""
        return len(self.data)

    def __getitem__(self, item: int):
        """
        Build one training/evaluation example.

        :param item: Index into the filtered example list.
        :return: ``(ques_id, feats, boxes, ques, target)`` when the example has
            a 'label' entry, otherwise ``(ques_id, feats, boxes, ques)``.
            ``boxes`` are normalized by the image size and ``target`` is a
            float tensor of length ``num_answers`` holding the label scores.
        :raises AssertionError: if the number of boxes/features disagrees with
            'num_boxes', or a normalized coordinate falls outside [0, 1]
            (with a 1e-5 tolerance).
        """
        datum = self.data[item]

        img_id = datum['img_id']
        ques_id = datum['question_id']
        ques = datum['sent']

        # Get image info
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        # Work on copies: the normalization below is in-place and must not
        # modify the arrays held in the shared image cache.
        boxes = img_info['boxes'].copy()
        feats = img_info['features'].copy()
        assert len(boxes) == len(feats) == obj_num

        # Normalize the boxes (to 0 ~ 1)
        # Columns 0 and 2 are scaled by the width, columns 1 and 3 by the
        # height (presumably an x1, y1, x2, y2 layout -- confirm with the TSV).
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)

        # Unlabeled example: there is no target to build.
        if 'label' not in datum:
            return ques_id, feats, boxes, ques

        # Create target
        ans2label = self.raw_dataset.ans2label
        target = torch.zeros(self.raw_dataset.num_answers)
        for ans, score in datum['label'].items():
            # Answers outside the vocabulary are silently skipped.
            if ans in ans2label:
                target[ans2label[ans]] = score
        return ques_id, feats, boxes, ques, target
class GQAEvaluator:
    """Scores predicted answers against the labels of a ``GQADataset``."""

    def __init__(self, dataset: "GQADataset"):
        """
        :param dataset: Dataset whose ``id2datum`` mapping supplies the labels.
        """
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """
        Compute the mean label score of the predicted answers.

        Each prediction earns the score its answer has in the example's
        'label' dict, or nothing when the answer is not listed there.

        :param quesid2ans: A dict mapping question id to its predicted answer.
        :return: The summed score divided by the number of predictions;
            ``0.`` when ``quesid2ans`` is empty.
        :raises KeyError: if a question id is unknown to the dataset or its
            example has no 'label' entry.
        """
        # Nothing predicted: report zero instead of dividing by zero.
        if not quesid2ans:
            return 0.

        score = 0.
        for quesid, ans in quesid2ans.items():
            label = self.dataset.id2datum[quesid]['label']
            score += label.get(ans, 0.)
        return score / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump the result to a GQA-challenge submittable json file.

        GQA json file submission requirement:
            results = [result]
            result = {
                "questionId": str,  # Note: it's actually an int number but the server requires a str.
                "prediction": str
            }

        :param quesid2ans: A dict mapping question id to its predicted answer.
        :param path: The file path to save the json file.
        :return: None
        """
        result = [
            {'questionId': ques_id, 'prediction': ans}
            for ques_id, ans in quesid2ans.items()
        ]
        with open(path, 'w') as f:
            json.dump(result, f, indent=4, sort_keys=True)
|