WwYc's picture
Upload 61 files
08d7644 verified
raw
history blame
4.92 kB
# coding=utf-8
# Copyleft 2019 project LXRT.
import json
import numpy as np
from torch.utils.data import Dataset
from param import args
from utils import load_obj_tsv
# Load only part of the dataset for fast checking.
# Note that these limits count images, not data examples,
# so every example whose images are all loaded is used.
TINY_IMG_NUM = 512
FAST_IMG_NUM = 5000
class NLVR2Dataset:
    """In-memory collection of NLVR2 examples read from JSON split files.

    An NLVR2 data example in json file:
    {
        "identifier": "train-10171-0-0",
        "img0": "train-10171-0-img0",
        "img1": "train-10171-0-img1",
        "label": 0,
        "sent": "An image shows one leather pencil case, displayed open with writing implements tucked inside.",
        "uid": "nlvr2_train_0"
    }

    Attributes set by the constructor:
        name: the raw ``splits`` argument, e.g. ``"train,valid"``.
        splits: ``name`` split on commas into a list of split names.
        data: the concatenated example dicts of every split, in split order.
        id2datum: mapping from each example's ``"uid"`` to the example dict.
    """

    def __init__(self, splits: str):
        """Load every split listed in ``splits`` from ``data/nlvr2/<split>.json``.

        :param splits: comma-separated split names (no spaces), one JSON file
            per name. Each file must hold a list of example dicts.
        """
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets to data
        self.data = []
        for split in self.splits:
            # Use a context manager so the handle is closed deterministically,
            # and decode explicitly as UTF-8 (the JSON interchange encoding)
            # instead of relying on the platform's locale default.
            with open("data/nlvr2/%s.json" % split, encoding='utf-8') as f:
                self.data.extend(json.load(f))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # List to dict (for evaluation and others)
        self.id2datum = {datum['uid']: datum for datum in self.data}

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)
"""
An example in obj36 tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
"""
class NLVR2TorchDataset(Dataset):
    """Torch ``Dataset`` pairing NLVR2 examples with per-image detection features.

    The constructor reads object-feature TSV files through ``load_obj_tsv`` and
    keeps only the examples for which both ``img0`` and ``img1`` were loaded.
    """

    def __init__(self, dataset: NLVR2Dataset):
        """Load the feature files matching ``dataset`` and filter its examples.

        :param dataset: the raw NLVR2 examples; its ``splits``, ``name`` and
            ``data`` attributes are read here.
        """
        super().__init__()
        self.raw_dataset = dataset

        # Per-file image limit handed to load_obj_tsv; -1 when neither the
        # tiny nor the fast flag is set (presumably "no limit" -- confirm
        # against load_obj_tsv).
        topk = TINY_IMG_NUM if args.tiny else (FAST_IMG_NUM if args.fast else -1)

        # (split key, where to look for it, feature file).
        # NOTE(review): 'test' is matched as a substring of the raw name string,
        # unlike 'train'/'valid', which are matched against the split list.
        feature_files = (
            ('train', dataset.splits, 'data/nlvr2_imgfeat/train_obj36.tsv'),
            ('valid', dataset.splits, 'data/nlvr2_imgfeat/valid_obj36.tsv'),
            ('test', dataset.name, 'data/nlvr2_imgfeat/test_obj36.tsv'),
        )
        loaded = []
        for split_key, requested, tsv_path in feature_files:
            if split_key in requested:
                loaded.extend(load_obj_tsv(tsv_path, topk=topk))

        # Index the loaded image records by their image id.
        self.imgid2img = {record['img_id']: record for record in loaded}

        # Keep an example only when both of its images have features.
        self.data = [
            example
            for example in self.raw_dataset.data
            if example['img0'] in self.imgid2img and example['img1'] in self.imgid2img
        ]
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

    def __len__(self):
        """Return the number of examples that survived the image filter."""
        return len(self.data)

    def _image_inputs(self, img_id):
        """Return ``(features, normalized boxes)`` copies for one image id."""
        info = self.imgid2img[img_id]
        boxes = info['boxes'].copy()
        feats = info['features'].copy()
        assert len(boxes) == len(feats)

        # Scale box coordinates into the 0 ~ 1 range: columns 0 and 2 by the
        # image width, columns 1 and 3 by the image height.
        height, width = info['img_h'], info['img_w']
        boxes[..., (0, 2)] /= width
        boxes[..., (1, 3)] /= height
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)
        return feats, boxes

    def __getitem__(self, item: int):
        """Return one example as a tuple.

        The tuple is ``(uid, feats, boxes, sent, label)`` when the example has
        a ``'label'`` key and ``(uid, feats, boxes, sent)`` otherwise. ``feats``
        and ``boxes`` stack the two images (``img0`` first) along a new
        leading axis.
        """
        example = self.data[item]
        uid = example['uid']
        sentence = example['sent']

        # Get image info for both images of the pair.
        per_image = [self._image_inputs(example[key]) for key in ('img0', 'img1')]
        feats = np.stack([image_feats for image_feats, _ in per_image])
        boxes = np.stack([image_boxes for _, image_boxes in per_image])

        # Unlabeled examples yield a shorter tuple.
        if 'label' not in example:
            return uid, feats, boxes, sentence
        return uid, feats, boxes, sentence, example['label']
class NLVR2Evaluator:
    """Score and export NLVR2 predictions against a dataset's examples."""

    def __init__(self, dataset: "NLVR2Dataset"):
        """Keep a reference to ``dataset``; its ``id2datum`` mapping is used later.

        :param dataset: object exposing ``id2datum`` (uid -> example dict).
        """
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """Return the fraction of predictions equal to the stored labels.

        :param quesid2ans: mapping from example uid to predicted answer.
        :return: accuracy as a float; 0.0 when ``quesid2ans`` is empty.
        :raises KeyError: if a uid is unknown or its example has no 'label'.
        """
        # Guard the empty case explicitly instead of dividing by zero.
        if not quesid2ans:
            return 0.
        id2datum = self.dataset.id2datum
        correct = sum(
            1
            for quesid, ans in quesid2ans.items()
            if ans == id2datum[quesid]['label']
        )
        return float(correct) / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump result to a CSV file, which is compatible with NLVR2 evaluation system.
        NLVR2 CSV file requirement:
            Each line contains: identifier, answer

        :param quesid2ans: nlvr2 uid to ans; an answer equal to 1 is written
            as "True", anything else as "False".
        :param path: The desired path of saved file.
        :return: None
        """
        with open(path, 'w') as f:
            for uid, ans in quesid2ans.items():
                idt = self.dataset.id2datum[uid]["identifier"]
                ans = 'True' if ans == 1 else 'False'
                f.write("%s,%s\n" % (idt, ans))