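"""Lightning data modules and dataset for WebVid-CoVR composed video retrieval.

Each sample pairs reference-video frames with a pre-computed target-video
embedding and the text edit (caption) that relates them.
"""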

import ast
from pathlib import Path

import pandas as pd
import torch
from lightning import LightningDataModule
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from src.data.transforms import transform_test, transform_train
from src.data.utils import FrameLoader, id2int, pre_caption
from src.tools.files import write_txt
from src.tools.utils import print_dist

Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombWarning


class WebVidCoVRDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int,
        num_workers: int = 4,
        pin_memory: bool = True,
        annotation: dict = {"train": "", "val": ""},
        vid_dirs: dict = {"train": "", "val": ""},
        emb_dirs: dict = {"train": "", "val": ""},
        image_size: int = 384,
        emb_pool: str = "query",
        iterate: str = "pth2",
        vid_query_method: str = "middle",
        vid_frames: int = 1,
        **kwargs,  # type: ignore
    ) -> None:
        super().__init__()

        # save_hyperparameters exposes the init params through `self.hparams`
        # and ensures they are stored in checkpoints.
        self.save_hyperparameters(logger=False)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.emb_pool = emb_pool
        self.iterate = iterate
        self.vid_query_method = vid_query_method
        self.vid_frames = vid_frames

        self.transform_train = transform_train(image_size)
        self.transform_test = transform_test(image_size)
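
        # Train/val datasets are constructed eagerly here rather than in
        # setup(), so they exist before the trainer requests dataloaders.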
        self.data_train = WebVidCoVRDataset(
            transform=self.transform_train,
            annotation=annotation["train"],
            vid_dir=vid_dirs["train"],
            emb_dir=emb_dirs["train"],
            split="train",
            emb_pool=self.emb_pool,
            iterate=self.iterate,
            vid_query_method=self.vid_query_method,
            vid_frames=self.vid_frames,
        )
        self.data_val = WebVidCoVRDataset(
            transform=self.transform_test,
            annotation=annotation["val"],
            vid_dir=vid_dirs["val"],
            emb_dir=emb_dirs["val"],
            split="val",
            emb_pool=self.emb_pool,
            iterate=self.iterate,
            vid_query_method=self.vid_query_method,
            vid_frames=self.vid_frames,
        )

    def prepare_data(self):
        # Runs on a single GPU/TPU only (not on every process in DDP):
        # download data, pre-process, split, save to disk, etc.
        pass

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
            drop_last=False,
        )


class WebVidCoVRTestDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int,
        annotation: str,
        vid_dirs: str,
        emb_dirs: str,
        num_workers: int = 4,
        pin_memory: bool = True,
        image_size: int = 384,
        emb_pool: str = "query",
        iterate: str = "pth2",
        vid_query_method: str = "middle",
        vid_frames: int = 1,
        **kwargs,  # type: ignore
    ) -> None:
        super().__init__()
        self.save_hyperparameters(logger=False)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.emb_pool = emb_pool
        self.iterate = iterate
        self.vid_query_method = vid_query_method
        self.vid_frames = vid_frames

        self.transform_test = transform_test(image_size)
        self.data_test = WebVidCoVRDataset(
            transform=self.transform_test,
            annotation=annotation,
            vid_dir=vid_dirs,
            emb_dir=emb_dirs,
            split="test",
            emb_pool=self.emb_pool,
            iterate=self.iterate,
            vid_query_method=self.vid_query_method,
            vid_frames=self.vid_frames,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
            drop_last=False,
        )


class WebVidCoVRDataset(Dataset):
    def __init__(
        self,
        transform,
        annotation: str,
        vid_dir: str,
        emb_dir: str,
        split: str,
        max_words: int = 30,
        emb_pool: str = "query",
        iterate: str = "pth2",
        vid_query_method: str = "middle",
        vid_frames: int = 1,
    ) -> None:
        super().__init__()
        self.transform = transform
        self.annotation_pth = annotation
        assert Path(annotation).exists(), f"Annotation file {annotation} does not exist"
        self.df = pd.read_csv(annotation)

        self.vid_dir = Path(vid_dir)
        self.emb_dir = Path(emb_dir)
        assert self.vid_dir.exists(), f"Video directory {self.vid_dir} does not exist"
        assert self.emb_dir.exists(), f"Embedding directory {emb_dir} does not exist"

        assert split in [
            "train",
            "val",
            "test",
        ], f"Invalid split: {split}, must be one of train, val, or test"
        self.split = split
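
        # Index every video (*.mp4) and pre-computed embedding (*.pth) by a
        # "<subdir>/<stem>" id so annotation rows can be matched to files.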
        vid_pths = self.vid_dir.glob("*/*.mp4")
        emb_pths = self.emb_dir.glob("*/*.pth")
        id2vidpth = {
            vid_pth.parent.stem + "/" + vid_pth.stem: vid_pth for vid_pth in vid_pths
        }
        id2embpth = {
            emb_pth.parent.stem + "/" + emb_pth.stem: emb_pth for emb_pth in emb_pths
        }
        assert len(id2vidpth) > 0, f"No videos found in {vid_dir}"
        assert len(id2embpth) > 0, f"No embeddings found in {emb_dir}"
        self.df["path1"] = self.df["pth1"].apply(lambda x: id2vidpth.get(x, None))  # type: ignore
        self.df["path2"] = self.df["pth2"].apply(lambda x: id2embpth.get(x, None))  # type: ignore
        # Count unique missing paths
        missing_pth1 = self.df[self.df["path1"].isna()]["pth1"].unique().tolist()
        missing_pth1.sort()
        total_pth1 = self.df["pth1"].nunique()
        missing_pth2 = self.df[self.df["path2"].isna()]["pth2"].unique().tolist()
        missing_pth2.sort()
        total_pth2 = self.df["pth2"].nunique()
        if len(missing_pth1) > 0:
            print_dist(
                f"Missing {len(missing_pth1)} pth1's ({len(missing_pth1) / total_pth1 * 100:.1f}%), saving them to missing_pth1-{split}.txt"
            )
            if split == "test":
                raise ValueError(
                    f"Missing {len(missing_pth1)} pth1's ({len(missing_pth1) / total_pth1 * 100:.1f}%) in test split"
                )
            write_txt(missing_pth1, f"missing_pth1-{split}.txt")
        if len(missing_pth2) > 0:
            print_dist(
                f"Missing {len(missing_pth2)} pth2's ({len(missing_pth2) / total_pth2 * 100:.1f}%), saving them to missing_pth2-{split}.txt"
            )
            if split == "test":
                raise ValueError(
                    f"Missing {len(missing_pth2)} pth2's ({len(missing_pth2) / total_pth2 * 100:.1f}%) in test split"
                )
            write_txt(missing_pth2, f"missing_pth2-{split}.txt")

        # Drop rows whose video or embedding file is missing
        self.df = self.df[self.df["path1"].notna()]
        self.df = self.df[self.df["path2"].notna()]
        self.df.reset_index(drop=True, inplace=True)

        self.max_words = max_words

        assert emb_pool in [
            "middle",
            "mean",
            "query",
        ], f"Invalid emb_pool: {emb_pool}, must be one of middle, mean, or query"
        self.emb_pool = emb_pool
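
        # `iterate` selects the unit of iteration: __getitem__ draws one row
        # per unique value of this column ("idx"/"triplets" iterate over rows).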
        if iterate in ["idx", "triplets"]:
            iterate = "idx"
            self.df["idx"] = self.df.index
        self.iterate = iterate
        assert iterate in self.df.columns, f"{iterate} not in {Path(annotation).stem}"
        self.target_txts = self.df[iterate].unique()
        self.df.sort_values(iterate, inplace=True)
        self.df.reset_index(drop=True, inplace=True)
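
        # Map the string video ids to integer ids used to identify
        # reference/target pairs (pairid2ref / pairid2tar).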
self.df["int1"] = self.df["pth1"].apply(lambda x: id2int(x, sub="0")) | |
self.df["int2"] = self.df["pth2"].apply(lambda x: id2int(x, sub="0")) | |
self.pairid2ref = self.df["int1"].to_dict() | |
assert ( | |
self.df["int1"].nunique() == self.df["pth1"].nunique() | |
), "int1 is not unique" | |
assert ( | |
self.df["int2"].nunique() == self.df["pth2"].nunique() | |
), "int2 is not unique" | |
# int2id is a dict with key: int1, value: pth1 | |
self.int2id = self.df.groupby("int1")["pth1"].apply(set).to_dict() | |
self.int2id = {k: list(v)[0] for k, v in self.int2id.items()} | |
self.pairid2tar = self.df["int2"].to_dict() | |
self.df.set_index(iterate, inplace=True) | |
self.df[iterate] = self.df.index | |
if split == "test": | |
assert ( | |
len(self.target_txts) == self.df.shape[0] | |
), "Test split should have one caption per row" | |

        assert vid_query_method in [
            "middle",
            "random",
            "sample",
        ], f"Invalid vid_query_method: {vid_query_method}, must be one of middle, random, or sample"
        self.frame_loader = FrameLoader(
            transform=self.transform, method=vid_query_method, frames_video=vid_frames
        )

    def __len__(self) -> int:
        return len(self.target_txts)

    def __getitem__(self, index):
        target_txt = self.target_txts[index]
        ann = self.df.loc[target_txt]
        if ann.ndim > 1:
            # Multiple rows share this target: sample one of them
            ann = ann.sample()
            ann = ann.iloc[0]

        reference_pth = str(ann["path1"])
        reference_vid = self.frame_loader(reference_pth)
        caption = pre_caption(ann["edit"], self.max_words)

        target_pth = str(ann["path2"])
        # map_location avoids device errors when embeddings were saved on GPU
        target_emb = torch.load(target_pth, map_location="cpu")
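
        # Pool the per-frame target embeddings [F, E] into one vector [E]:
        #   middle - take the middle frame's embedding
        #   mean   - average over frames
        #   query  - softmax-weighted sum of frames using per-frame scores
        #            (temperature 0.1); uniform weights if no scores are given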
        if self.emb_pool == "middle":
            target_emb = target_emb[len(target_emb) // 2]
        elif self.emb_pool == "mean":
            target_emb = target_emb.mean(0)
        elif self.emb_pool == "query":
            vid_scores = ast.literal_eval(str(ann["scores"]))
            if len(vid_scores) == 0:
                vid_scores = [1.0] * len(target_emb)
            vid_scores = torch.Tensor(vid_scores)
            vid_scores = (vid_scores / 0.1).softmax(dim=0)
            target_emb = torch.einsum("f,fe->e", vid_scores, target_emb)

        return reference_vid, target_emb, caption, index
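

if __name__ == "__main__":
    # Minimal smoke-test sketch. The annotation CSV, video directory, and
    # embedding directory below are hypothetical placeholders; point them at
    # your own WebVid-CoVR layout before running.
    dm = WebVidCoVRTestDataModule(
        batch_size=2,
        annotation="annotations/webvid-covr_test.csv",  # placeholder path
        vid_dirs="datasets/WebVid/videos",  # placeholder path
        emb_dirs="datasets/WebVid/embs",  # placeholder path
        num_workers=0,
    )
    reference_vid, target_emb, caption, idx = dm.data_test[0]
    print(type(reference_vid), target_emb.shape, caption, idx)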