import json
import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder

from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL


def _load_json(path):
    """Parse and return the JSON document stored at ``path``.

    The ``with`` block closes the file as soon as parsing finishes instead of
    leaving the handle open until garbage collection.
    """
    with open(path, "r") as f:
        return json.load(f)


def _load_image(path):
    """Open the image at ``path`` and read its pixel data eagerly.

    ``Image.open`` is lazy: it keeps the underlying file open until the pixel
    data is read. Calling ``load()`` immediately lets Pillow release the file
    handle (for single-frame images), so holding many samples at once does not
    exhaust file descriptors.
    """
    image = Image.open(path)
    image.load()
    return image


class COCOFlickrDataset(Dataset):
    """Image/caption pairs read from a COCO- or Flickr-style annotation file.

    Each item corresponds to one entry of the JSON file's ``"annotations"``
    list, paired with the image that entry refers to.
    """

    def __init__(
        self,
        image_dir_path,
        annotations_path,
        is_flickr=False,
    ):
        """
        Args:
            image_dir_path: directory containing the image files.
            annotations_path: JSON file with a top-level ``"annotations"``
                list; entries are read for ``"image_id"`` and ``"caption"``.
            is_flickr: if True, images are named ``<image_id>.jpg``;
                otherwise ``<image_id zero-padded to 12 digits>.jpg``
                (which requires an integer ``image_id``).
        """
        self.image_dir_path = image_dir_path
        self.annotations = _load_json(annotations_path)["annotations"]
        self.is_flickr = is_flickr

    def __len__(self):
        return len(self.annotations)

    def get_img_path(self, idx):
        """Return the image path for annotation number ``idx``."""
        image_id = self.annotations[idx]["image_id"]
        if self.is_flickr:
            return f"{self.image_dir_path}/{image_id}.jpg"
        return f"{self.image_dir_path}/{image_id:012d}.jpg"

    def __getitem__(self, idx):
        """Return ``{"image": PIL image, "caption": ..., "image_id": ...}``."""
        annotation = self.annotations[idx]
        image = _load_image(self.get_img_path(idx))
        return {
            "image": image,
            "caption": annotation["caption"],
            "image_id": annotation["image_id"],
        }


class VQADataset(Dataset):
    """Question/answer samples read from VQA-format question and annotation files."""

    def __init__(
        self,
        image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
        question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
        annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
        vqa_dataset="vqa",
        image_prefix="COCO_val2014_",
    ):
        """
        Args:
            image_dir_path: directory containing the image files.
            question_path: JSON file with a top-level ``"questions"`` list.
            annotations_path: JSON file with a top-level ``"annotations"``
                list. NOTE(review): ``__getitem__`` pairs questions and
                annotations purely by list index, so both files are assumed
                to be in the same order -- confirm for each dataset used.
            vqa_dataset: ``"vqa"`` or ``"ok_vqa"``; any other value makes
                ``get_img_path`` raise.
            image_prefix: filename prefix placed before the zero-padded image
                id. The default keeps the historical ``"COCO_val2014_"``;
                images in a train2014 directory need ``"COCO_train2014_"``.
        """
        self.questions = _load_json(question_path)["questions"]
        self.answers = _load_json(annotations_path)["annotations"]
        self.image_dir_path = image_dir_path
        self.vqa_dataset = vqa_dataset
        self.image_prefix = image_prefix

    def __len__(self):
        return len(self.questions)

    def get_img_path(self, question):
        """Return the image path for a question dict (needs an int ``"image_id"``).

        Raises:
            ValueError: if ``vqa_dataset`` is not a supported dataset name.
        """
        # Both supported datasets use the same COCO file-naming scheme.
        if self.vqa_dataset in ("vqa", "ok_vqa"):
            return os.path.join(
                self.image_dir_path,
                f"{self.image_prefix}{question['image_id']:012d}.jpg",
            )
        raise ValueError(f"Unknown VQA dataset {self.vqa_dataset}")

    def __getitem__(self, idx):
        """Return the image, question text, answer strings and question id for ``idx``."""
        question = self.questions[idx]
        answers = self.answers[idx]
        image = _load_image(self.get_img_path(question))
        return {
            "image": image,
            "question": question["question"],
            "answers": [a["answer"] for a in answers["answers"]],
            "question_id": question["question_id"],
        }


class ImageNetDataset(ImageFolder):
    """Class to represent the ImageNet1k dataset."""

    def __init__(self, root, **kwargs):
        # All keyword arguments are forwarded unchanged to ImageFolder.
        super().__init__(root=root, **kwargs)

    def __getitem__(self, idx):
        """Return the sample with its numeric class id and human-readable class name."""
        sample, target = super().__getitem__(idx)
        target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target]
        return {
            "image": sample,
            "class_id": target,  # numeric ID of the ImageNet class
            "class_name": target_label,  # human-readable name of ImageNet class
        }


if __name__ == "__main__":
    # Smoke test: print every sample of the default VQA split. The default
    # paths point at train2014, so the matching filename prefix is passed.
    vqa_dataset = VQADataset(image_prefix="COCO_train2014_")
    for sample in vqa_dataset:
        print(sample)