# NOTE(review): the three lines below were non-code scraper artifacts from the
# hosting page ("Spaces: / Sleeping / Sleeping"); commented out so the module
# stays importable. Original text preserved:
# Spaces:
# Sleeping
# Sleeping
import json
from pathlib import Path

import torch
from lightning import LightningDataModule
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from src.data.transforms import transform_test, transform_train
from src.data.utils import id2int, pre_caption

Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombWarning
class CIRRDataModule(LightningDataModule):
    """Lightning datamodule for the CIRR train/val splits.

    Eagerly builds a train `CIRRDataset` (augmenting transform) and a val
    `CIRRDataset` (deterministic transform) in ``__init__`` and exposes the
    matching dataloaders.
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int = 4,
        pin_memory: bool = True,
        annotation: dict = None,
        img_dirs: dict = None,
        emb_dirs: dict = None,
        image_size: int = 384,
        **kwargs,  # type: ignore
    ) -> None:
        """Configure loaders and construct the train/val datasets.

        Args:
            batch_size: Samples per batch for both dataloaders.
            num_workers: Dataloader worker processes.
            pin_memory: Passed through to ``DataLoader``.
            annotation: Mapping with keys ``"train"``/``"val"`` giving the
                annotation JSON path for each split. Defaults to empty paths.
            img_dirs: Mapping split -> image directory.
            emb_dirs: Mapping split -> precomputed-embedding directory.
            image_size: Square side length fed to the transform builders.
            **kwargs: Ignored; accepted for config-driven instantiation.
        """
        super().__init__()
        self.save_hyperparameters(logger=False)
        # Fix: the original used shared mutable dict literals as default
        # arguments; use None sentinels and materialize per-call instead.
        if annotation is None:
            annotation = {"train": "", "val": ""}
        if img_dirs is None:
            img_dirs = {"train": "", "val": ""}
        if emb_dirs is None:
            emb_dirs = {"train": "", "val": ""}
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.transform_train = transform_train(image_size)
        self.transform_test = transform_test(image_size)
        self.data_train = CIRRDataset(
            transform=self.transform_train,
            annotation=annotation["train"],
            img_dir=img_dirs["train"],
            emb_dir=emb_dirs["train"],
            split="train",
        )
        self.data_val = CIRRDataset(
            transform=self.transform_test,
            annotation=annotation["val"],
            img_dir=img_dirs["val"],
            emb_dir=emb_dirs["val"],
            split="val",
        )

    def prepare_data(self):
        """No-op: data is expected to already exist on disk.

        Lightning calls this once per node (not per GPU in DDP); download /
        pre-processing would go here if it were needed.
        """
        pass

    def train_dataloader(self):
        """Shuffled training loader; drops the last ragged batch."""
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Deterministic validation loader over the full val split."""
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
            drop_last=False,
        )
class CIRRTestDataModule(LightningDataModule):
    """Lightning datamodule for the CIRR test split.

    Builds a single test `CIRRDataset` (deterministic transform) at
    construction time and serves it through ``test_dataloader``.
    """

    def __init__(
        self,
        batch_size: int,
        annotation: str,
        img_dirs: str,
        emb_dirs: str,
        num_workers: int = 4,
        pin_memory: bool = True,
        image_size: int = 384,
        **kwargs,  # type: ignore
    ) -> None:
        """Store loader settings and construct the test dataset.

        Args:
            batch_size: Samples per batch.
            annotation: Path to the test-split annotation JSON.
            img_dirs: Test-split image directory.
            emb_dirs: Test-split precomputed-embedding directory.
            num_workers: Dataloader worker processes.
            pin_memory: Passed through to ``DataLoader``.
            image_size: Square side length fed to the transform builder.
            **kwargs: Ignored; accepted for config-driven instantiation.
        """
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.transform_test = transform_test(image_size)
        self.data_test = CIRRDataset(
            transform=self.transform_test,
            annotation=annotation,
            img_dir=img_dirs,
            emb_dir=emb_dirs,
            split="test",
        )

    def test_dataloader(self):
        """Deterministic loader over the full test split (no shuffling)."""
        loader_cfg = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
            drop_last=False,
        )
        return DataLoader(dataset=self.data_test, **loader_cfg)
class CIRRDataset(Dataset):
    """CIRR composed-image-retrieval dataset.

    Each annotation pairs a reference image with a relative caption; for the
    train/val splits it also names a hard target image. Reference images are
    read as PNGs from ``img_dir`` and precomputed embeddings as ``.pth``
    files from ``emb_dir``; all referenced assets are validated at init.
    """

    def __init__(
        self,
        transform,
        annotation: str,
        img_dir: str,
        emb_dir: str,
        split: str,
        max_words: int = 30,
    ) -> None:
        """Load annotations, index asset paths, and validate everything.

        Args:
            transform: Callable applied to the loaded PIL reference image.
            annotation: Path to the split's annotation JSON (a list of dicts
                with keys ``pairid``, ``reference``, ``caption``, ``img_set``
                and, outside the test split, ``target_hard``).
            img_dir: Root directory of the split's PNG images.
            emb_dir: Root directory of the split's ``.pth`` embeddings.
            split: One of ``"train"``, ``"val"``, ``"test"``.
            max_words: Caption truncation length passed to ``pre_caption``.
        """
        super().__init__()
        self.transform = transform
        self.annotation_pth = annotation
        assert Path(annotation).exists(), f"Annotation file {annotation} does not exist"
        # Fix: the original `json.load(open(annotation, "r"))` leaked the
        # file handle; use a context manager so it is closed promptly.
        with open(annotation, "r") as f:
            self.annotation = json.load(f)
        self.split = split
        self.max_words = max_words
        self.img_dir = Path(img_dir)
        self.emb_dir = Path(emb_dir)
        assert split in [
            "train",
            "val",
            "test",
        ], f"Invalid split: {split}, must be one of train, val, or test"
        assert self.img_dir.exists(), f"Image directory {img_dir} does not exist"
        assert self.emb_dir.exists(), f"Embedding directory {emb_dir} does not exist"
        # pairid -> integer id of the reference image.
        self.pairid2ref = {
            ann["pairid"]: id2int(ann["reference"]) for ann in self.annotation
        }
        # Integer id -> original string id; must be a bijection over references.
        self.int2id = {
            id2int(ann["reference"]): ann["reference"] for ann in self.annotation
        }
        ids = {ann["reference"] for ann in self.annotation}
        assert len(self.int2id) == len(ids), "Reference ids are not unique"
        # pairid -> ids of all images in the same subset group.
        # NOTE(review): here id2int receives a *list* of member ids, so it
        # presumably maps element-wise — confirm against src.data.utils.id2int.
        self.pairid2members = {
            ann["pairid"]: id2int(ann["img_set"]["members"]) for ann in self.annotation
        }
        # Test annotations carry no hard target, so the mapping is absent.
        if split != "test":
            self.pairid2tar = {
                ann["pairid"]: id2int(ann["target_hard"]) for ann in self.annotation
            }
        else:
            self.pairid2tar = None
        # Train assets are sharded one directory deep; val/test are flat.
        if split == "train":
            img_pths = self.img_dir.glob("*/*.png")
            emb_pths = self.emb_dir.glob("*/*.pth")
        else:
            img_pths = self.img_dir.glob("*.png")
            emb_pths = self.emb_dir.glob("*.pth")
        # File stem doubles as the image/embedding id throughout.
        self.id2imgpth = {img_pth.stem: img_pth for img_pth in img_pths}
        self.id2embpth = {emb_pth.stem: emb_pth for emb_pth in emb_pths}
        # Fail fast if any annotated asset is missing on disk.
        for ann in self.annotation:
            assert (
                ann["reference"] in self.id2imgpth
            ), f"Path to reference {ann['reference']} not found in {self.img_dir}"
            assert (
                ann["reference"] in self.id2embpth
            ), f"Path to reference {ann['reference']} not found in {self.emb_dir}"
            if split != "test":
                assert (
                    ann["target_hard"] in self.id2imgpth
                ), f"Path to target {ann['target_hard']} not found"
                assert (
                    ann["target_hard"] in self.id2embpth
                ), f"Path to target {ann['target_hard']} not found"

    def __len__(self) -> int:
        """Number of annotated (reference, caption) pairs in the split."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return one sample.

        Returns:
            test split: ``(reference_img, reference_feat, caption, pairid)``
            train/val:  ``(reference_img, target_feat, caption, pairid)``
            where the image tensor is produced by ``self.transform`` and the
            feature is loaded from the corresponding ``.pth`` file.
        """
        ann = self.annotation[index]
        reference_img_pth = self.id2imgpth[ann["reference"]]
        reference_img = Image.open(reference_img_pth).convert("RGB")
        reference_img = self.transform(reference_img)
        caption = pre_caption(ann["caption"], self.max_words)
        if self.split == "test":
            # NOTE(review): unlike the train/val branch below, this load is not
            # moved to CPU — assumes test embeddings were saved on CPU; confirm
            # (torch.load(..., map_location="cpu") would make both uniform).
            reference_feat = torch.load(self.id2embpth[ann["reference"]])
            return reference_img, reference_feat, caption, ann["pairid"]
        target_emb_pth = self.id2embpth[ann["target_hard"]]
        target_feat = torch.load(target_emb_pth).cpu()
        return (
            reference_img,
            target_feat,
            caption,
            ann["pairid"],
        )