# coding=utf-8
# Copyleft 2019 project LXRT.

import json

import numpy as np
import torch
from torch.utils.data import Dataset

from param import args
from utils import load_obj_tsv

# Load part of the dataset for fast checking.
# Notice that here is the number of images instead of the number of data,
# which means all related data to the images would be used.
TINY_IMG_NUM = 512
FAST_IMG_NUM = 5000
class GQADataset:
    """
    Holds the raw GQA question/answer data for one or more splits.

    A GQA data example in json file:
    {
        "img_id": "2375429",
        "label": {
            "pipe": 1.0
        },
        "question_id": "07333408",
        "sent": "What is on the white wall?"
    }
    """

    def __init__(self, splits: str):
        """Load the json data files named by `splits` (comma-separated split names)."""
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets to data
        self.data = []
        for split in self.splits:
            self.data.extend(json.load(open("data/gqa/%s.json" % split)))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # List to dict (for evaluation and others), keyed by question id.
        self.id2datum = {
            datum['question_id']: datum
            for datum in self.data
        }

        # Answers: the two files must form a consistent bijection
        # between answer strings and integer labels.
        self.ans2label = json.load(open("data/gqa/trainval_ans2label.json"))
        self.label2ans = json.load(open("data/gqa/trainval_label2ans.json"))
        assert len(self.ans2label) == len(self.label2ans)
        for ans, label in self.ans2label.items():
            assert self.label2ans[label] == ans

    @property
    def num_answers(self):
        # NOTE: must be a property, not a plain method — GQATorchDataset
        # reads `raw_dataset.num_answers` as an attribute when sizing the
        # target vector (torch.zeros(...)).
        return len(self.ans2label)

    def __len__(self):
        return len(self.data)
class GQABufferLoader():
    """Caches loaded image-feature tables so each tsv file is read at most
    once per (path, topk) combination."""

    def __init__(self):
        # Maps "<path>_<number>" -> list of image feature dicts.
        self.key2data = {}

    def load_data(self, name, number):
        """Return (and cache) the first `number` image features for `name`.

        `name == 'testdev'` selects the testdev tsv; every other name falls
        back to the shared train/valid Visual Genome tsv. `number` of -1
        means "load everything" (passed through as topk).
        """
        if name == 'testdev':
            path = "data/vg_gqa_imgfeat/gqa_testdev_obj36.tsv"
        else:
            path = "data/vg_gqa_imgfeat/vg_gqa_obj36.tsv"
        key = "%s_%d" % (path, number)
        if key not in self.key2data:
            self.key2data[key] = load_obj_tsv(
                path,
                topk=number
            )
        return self.key2data[key]
# Module-level singleton: train and valid share the same Visual Genome
# feature file, so one shared loader lets every dataset reuse a single copy.
gqa_buffer_loader = GQABufferLoader()


"""
Example in obj tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
"""
class GQATorchDataset(Dataset):
    """Torch dataset pairing GQA questions with pre-extracted object features."""

    def __init__(self, dataset: "GQADataset"):
        super().__init__()
        self.raw_dataset = dataset

        # Cap the number of loaded images for quick debugging runs.
        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM
        else:
            topk = -1  # load everything

        # Loading detection features to img_data.
        # Since images in train and valid both come from Visual Genome,
        # buffer the image loading to save memory.
        img_data = []
        if 'testdev' in dataset.splits or 'testdev_all' in dataset.splits:
            # Always loading all the data in testdev.
            img_data.extend(gqa_buffer_loader.load_data('testdev', -1))
        else:
            img_data.extend(gqa_buffer_loader.load_data('train', topk))
        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Only keep the data whose image features were actually loaded.
        self.data = [datum for datum in self.raw_dataset.data
                     if datum['img_id'] in self.imgid2img]
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item: int):
        datum = self.data[item]

        img_id = datum['img_id']
        ques_id = datum['question_id']
        ques = datum['sent']

        # Get image info; copy so in-place normalization below does not
        # corrupt the shared buffered arrays.
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        boxes = img_info['boxes'].copy()
        feats = img_info['features'].copy()
        assert len(boxes) == len(feats) == obj_num

        # Normalize the boxes (to 0 ~ 1). Boxes are (x1, y1, x2, y2);
        # assumes float arrays from the tsv loader — in-place /= would
        # fail on integer dtypes.
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1 + 1e-5)
        np.testing.assert_array_less(-boxes, 0 + 1e-5)

        # Create target: dense per-answer score vector over the answer
        # vocabulary; answers outside the vocabulary are dropped.
        if 'label' in datum:
            label = datum['label']
            target = torch.zeros(self.raw_dataset.num_answers)
            for ans, score in label.items():
                if ans in self.raw_dataset.ans2label:
                    target[self.raw_dataset.ans2label[ans]] = score
            return ques_id, feats, boxes, ques, target
        else:
            return ques_id, feats, boxes, ques
class GQAEvaluator:
    """Scores predictions against a GQADataset and dumps submission files."""

    def __init__(self, dataset: "GQADataset"):
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """Return the mean soft-accuracy of the predictions.

        Each predicted answer earns the score stored in that question's
        label dict (0 if the answer is not among the labels).
        """
        score = 0.
        for quesid, ans in quesid2ans.items():
            datum = self.dataset.id2datum[quesid]
            label = datum['label']
            if ans in label:
                score += label[ans]
        return score / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump the result to a GQA-challenge submittable json file.
        GQA json file submission requirement:
            results = [result]
            result = {
                "questionId": str,  # Note: it's a actually an int number but the server requires an str.
                "prediction": str
            }

        :param quesid2ans: A dict mapping question id to its predicted answer.
        :param path: The file path to save the json file.
        :return:
        """
        with open(path, 'w') as f:
            result = []
            for ques_id, ans in quesid2ans.items():
                result.append({
                    'questionId': ques_id,
                    'prediction': ans
                })
            json.dump(result, f, indent=4, sort_keys=True)