# coding=utf-8
# Copyleft 2019 project LXRT.
import json
import numpy as np
from torch.utils.data import Dataset
from param import args
from utils import load_obj_tsv
# Load part of the dataset for fast checking.
# Notice that here is the number of images instead of the number of data,
# which means all related data to the images would be used.
TINY_IMG_NUM = 512
FAST_IMG_NUM = 5000
class NLVR2Dataset:
    """
    Load NLVR2 annotation json file(s) and index the examples by uid.

    An NLVR2 data example in json file:
    {
        "identifier": "train-10171-0-0",
        "img0": "train-10171-0-img0",
        "img1": "train-10171-0-img1",
        "label": 0,
        "sent": "An image shows one leather pencil case, displayed open with writing implements tucked inside.",
        "uid": "nlvr2_train_0"
    }
    """
    def __init__(self, splits: str):
        """
        :param splits: comma-separated split names (e.g. "train" or "train,valid");
                       each split is loaded from data/nlvr2/<split>.json
        """
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets to data
        self.data = []
        for split in self.splits:
            # Use a context manager so the file handle is closed promptly;
            # the original json.load(open(...)) leaked the handle.
            with open("data/nlvr2/%s.json" % split) as f:
                self.data.extend(json.load(f))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # List to dict (for evaluation and others)
        self.id2datum = {
            datum['uid']: datum
            for datum in self.data
        }

    def __len__(self):
        return len(self.data)
"""
An example in obj36 tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
"""
class NLVR2TorchDataset(Dataset):
def __init__(self, dataset: NLVR2Dataset):
super().__init__()
self.raw_dataset = dataset
if args.tiny:
topk = TINY_IMG_NUM
elif args.fast:
topk = FAST_IMG_NUM
else:
topk = -1
# Loading detection features to img_data
img_data = []
if 'train' in dataset.splits:
img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/train_obj36.tsv', topk=topk))
if 'valid' in dataset.splits:
img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/valid_obj36.tsv', topk=topk))
if 'test' in dataset.name:
img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/test_obj36.tsv', topk=topk))
self.imgid2img = {}
for img_datum in img_data:
self.imgid2img[img_datum['img_id']] = img_datum
# Filter out the dataset
self.data = []
for datum in self.raw_dataset.data:
if datum['img0'] in self.imgid2img and datum['img1'] in self.imgid2img:
self.data.append(datum)
print("Use %d data in torch dataset" % (len(self.data)))
print()
def __len__(self):
return len(self.data)
def __getitem__(self, item: int):
datum = self.data[item]
ques_id = datum['uid']
ques = datum['sent']
# Get image info
boxes2 = []
feats2 = []
for key in ['img0', 'img1']:
img_id = datum[key]
img_info = self.imgid2img[img_id]
boxes = img_info['boxes'].copy()
feats = img_info['features'].copy()
assert len(boxes) == len(feats)
# Normalize the boxes (to 0 ~ 1)
img_h, img_w = img_info['img_h'], img_info['img_w']
boxes[..., (0, 2)] /= img_w
boxes[..., (1, 3)] /= img_h
np.testing.assert_array_less(boxes, 1+1e-5)
np.testing.assert_array_less(-boxes, 0+1e-5)
boxes2.append(boxes)
feats2.append(feats)
feats = np.stack(feats2)
boxes = np.stack(boxes2)
# Create target
if 'label' in datum:
label = datum['label']
return ques_id, feats, boxes, ques, label
else:
return ques_id, feats, boxes, ques
class NLVR2Evaluator:
def __init__(self, dataset: NLVR2Dataset):
self.dataset = dataset
def evaluate(self, quesid2ans: dict):
score = 0.
for quesid, ans in quesid2ans.items():
datum = self.dataset.id2datum[quesid]
label = datum['label']
if ans == label:
score += 1
return score / len(quesid2ans)
def dump_result(self, quesid2ans: dict, path):
"""
Dump result to a CSV file, which is compatible with NLVR2 evaluation system.
NLVR2 CSV file requirement:
Each line contains: identifier, answer
:param quesid2ans: nlvr2 uid to ans (either "True" or "False")
:param path: The desired path of saved file.
:return:
"""
with open(path, 'w') as f:
for uid, ans in quesid2ans.items():
idt = self.dataset.id2datum[uid]["identifier"]
ans = 'True' if ans == 1 else 'False'
f.write("%s,%s\n" % (idt, ans))