import json
from pathlib import Path

import torch
from lightning import LightningDataModule
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from src.data.transforms import transform_test, transform_train
from src.data.utils import id2int, pre_caption

Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombWarning


class CIRRDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int,
        num_workers: int = 4,
        pin_memory: bool = True,
        annotation: dict = {"train": "", "val": ""},
        img_dirs: dict = {"train": "", "val": ""},
        emb_dirs: dict = {"train": "", "val": ""},
        image_size: int = 384,
        **kwargs,  # type: ignore
    ) -> None:
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        self.transform_train = transform_train(image_size)
        self.transform_test = transform_test(image_size)

        self.data_train = CIRRDataset(
            transform=self.transform_train,
            annotation=annotation["train"],
            img_dir=img_dirs["train"],
            emb_dir=emb_dirs["train"],
            split="train",
        )
        self.data_val = CIRRDataset(
            transform=self.transform_test,
            annotation=annotation["val"],
            img_dir=img_dirs["val"],
            emb_dir=emb_dirs["val"],
            split="val",
        )

    def prepare_data(self):
        # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
        # download data, pre-process, split, save to disk, etc...
        pass

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
            drop_last=False,
        )


class CIRRTestDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int,
        annotation: str,
        img_dirs: str,
        emb_dirs: str,
        num_workers: int = 4,
        pin_memory: bool = True,
        image_size: int = 384,
        **kwargs,  # type: ignore
    ) -> None:
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        self.transform_test = transform_test(image_size)

        self.data_test = CIRRDataset(
            transform=self.transform_test,
            annotation=annotation,
            img_dir=img_dirs,
            emb_dir=emb_dirs,
            split="test",
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
            drop_last=False,
        )


class CIRRDataset(Dataset):
    def __init__(
        self,
        transform,
        annotation: str,
        img_dir: str,
        emb_dir: str,
        split: str,
        max_words: int = 30,
    ) -> None:
        super().__init__()

        self.transform = transform
        self.annotation_pth = annotation
        assert Path(annotation).exists(), f"Annotation file {annotation} does not exist"
        with open(annotation, "r") as f:
            self.annotation = json.load(f)
        self.split = split
        self.max_words = max_words
        self.img_dir = Path(img_dir)
        self.emb_dir = Path(emb_dir)

        assert split in [
            "train",
            "val",
            "test",
        ], f"Invalid split: {split}, must be one of train, val, or test"
        assert self.img_dir.exists(), f"Image directory {img_dir} does not exist"
        assert self.emb_dir.exists(), f"Embedding directory {emb_dir} does not exist"

        # Map each query pair to its reference image id (as an int), and keep the
        # inverse int -> id mapping so retrieval results can be decoded again.
        self.pairid2ref = {
            ann["pairid"]: id2int(ann["reference"]) for ann in self.annotation
        }
        self.int2id = {
            id2int(ann["reference"]): ann["reference"] for ann in self.annotation
        }
        ids = {ann["reference"] for ann in self.annotation}
        assert len(self.int2id) == len(ids), "Reference ids are not unique"

        self.pairid2members = {
            ann["pairid"]: id2int(ann["img_set"]["members"]) for ann in self.annotation
        }
        # The test split has no ground-truth targets.
        if split != "test":
            self.pairid2tar = {
                ann["pairid"]: id2int(ann["target_hard"]) for ann in self.annotation
            }
        else:
            self.pairid2tar = None

        # Train images/embeddings live in per-shard subdirectories; val and test
        # directories are flat.
        if split == "train":
            img_pths = self.img_dir.glob("*/*.png")
            emb_pths = self.emb_dir.glob("*/*.pth")
        else:
            img_pths = self.img_dir.glob("*.png")
            emb_pths = self.emb_dir.glob("*.pth")

        self.id2imgpth = {img_pth.stem: img_pth for img_pth in img_pths}
        self.id2embpth = {emb_pth.stem: emb_pth for emb_pth in emb_pths}

        # Every annotated image must have both a PNG and a pre-computed embedding
        # on disk; fail fast here rather than mid-epoch.
        for ann in self.annotation:
            assert (
                ann["reference"] in self.id2imgpth
            ), f"Path to reference {ann['reference']} not found in {self.img_dir}"
            assert (
                ann["reference"] in self.id2embpth
            ), f"Path to reference {ann['reference']} not found in {self.emb_dir}"
            if split != "test":
                assert (
                    ann["target_hard"] in self.id2imgpth
                ), f"Path to target {ann['target_hard']} not found in {self.img_dir}"
                assert (
                    ann["target_hard"] in self.id2embpth
                ), f"Path to target {ann['target_hard']} not found in {self.emb_dir}"

    def __len__(self) -> int:
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        reference_img_pth = self.id2imgpth[ann["reference"]]
        reference_img = Image.open(reference_img_pth).convert("RGB")
        reference_img = self.transform(reference_img)

        caption = pre_caption(ann["caption"], self.max_words)

        if self.split == "test":
            reference_feat = torch.load(self.id2embpth[ann["reference"]])
            return reference_img, reference_feat, caption, ann["pairid"]

        target_emb_pth = self.id2embpth[ann["target_hard"]]
        target_feat = torch.load(target_emb_pth).cpu()

        return (
            reference_img,
            target_feat,
            caption,
            ann["pairid"],
        )