import functools
from io import BytesIO

import torch
import torchvision
import torchvision.transforms.v2 as transforms
import wids
from torch.utils.data import DataLoader


def _video_shortener(video_tensor, length, generator=None):
    """Select a random temporal window of `length` frames from a (T, C, H, W) video tensor."""
    # max(..., 1) keeps randint valid when the clip has exactly `length` frames.
    high = max(video_tensor.shape[0] - length, 1)
    start = torch.randint(0, high, (1,), generator=generator).item()
    return video_tensor[start:start + length]


def select_video_extract(length=16, generator=None):
    """Return a transform that crops videos to `length` frames along the time axis."""
    return functools.partial(_video_shortener, length=length, generator=generator)


def my_collate_fn(batch):
    """Stack video tensors into a batch and keep the captions as a list of strings."""
    videos = torch.stack([sample[0] for sample in batch])
    txts = [sample[1] for sample in batch]
    return videos, txts


class WebVidDataset(wids.ShardListDataset):
    def __init__(self, shards, cache_dir, video_length=16, video_size=256,
                 video_length_offset=1, val=False, seed=42, **kwargs):
        self.val = val

        # Dedicated RNG so that validation can replay the same frame selection every epoch.
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)
        self.generator_init_state = self.generator.get_state()

        super().__init__(shards, cache_dir=cache_dir, keep=True, **kwargs)

        if isinstance(video_size, int):
            video_size = (video_size, video_size)
        self.video_size = video_size

        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        self.transform = transforms.Compose(
            [
                # Temporal crop, then spatial resize/crop; augmentations are disabled for validation.
                select_video_extract(length=video_length + video_length_offset, generator=self.generator),
                transforms.Resize(size=video_size),
                transforms.RandomCrop(size=video_size) if not self.val else transforms.CenterCrop(size=video_size),
                transforms.RandomHorizontalFlip() if not self.val else transforms.Identity(),
            ]
        )

        self.add_transform(self._make_sample)

    def _make_sample(self, sample):
        # In validation mode, reset the RNG so every pass sees the same temporal crops.
        if self.val:
            self.generator.set_state(self.generator_init_state)
        # wids exposes the .mp4 entry as a file-like object and decodes .txt to str.
        video = torchvision.io.read_video(
            BytesIO(sample[".mp4"].read()), output_format="TCHW", pts_unit="sec"
        )[0]
        label = sample[".txt"]
        return self.transform(video), label


if __name__ == "__main__":
    # Smoke test: the shard index path comes from the original script;
    # cache_dir is an assumed local directory, adjust to your environment.
    dataset = WebVidDataset(
        shards='/users/Etu9/3711799/onlyflow/data/webvid/desc.json',
        cache_dir='./cache',
        video_length=16,
        video_size=256,
        video_length_offset=0,
    )

    sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)
    dataloader = DataLoader(
        dataset,
        collate_fn=my_collate_fn,
        batch_size=4,
        sampler=sampler,
        num_workers=4,
    )

    # Each batch is a (B, T, C, H, W) video tensor plus a list of caption strings.
    for i, (videos, captions) in enumerate(dataloader):
        print(i, videos.shape, captions)
        if i > 10:
            break