# Spaces:
# Sleeping
# Sleeping
import argparse
import functools
import tempfile
from io import BytesIO

import torch
import torchvision
import torchvision.transforms.v2 as transforms
import wids
from torch.utils.data import DataLoader
def _video_shortener(video_tensor, length, generator=None): | |
start = torch.randint(0, video_tensor.shape[0] - length, (1,), generator=generator) | |
return video_tensor[start:start + length] | |
def select_video_extract(length=16, generator=None):
    """Build a transform that keeps a random ``length``-frame window of a video.

    Args:
        length: Number of consecutive frames the returned callable extracts.
        generator: Optional ``torch.Generator`` forwarded to the sampler.

    Returns:
        A ``functools.partial`` of ``_video_shortener`` with ``length`` and
        ``generator`` bound, to be called as ``f(video_tensor)``.
    """
    bound_options = {"length": length, "generator": generator}
    return functools.partial(_video_shortener, **bound_options)
def my_collate_fn(batch):
    """Collate ``(video, text, ...)`` samples into a batch.

    Args:
        batch: Sequence of samples; element 0 of each is a tensor (all of the
            same shape), element 1 is its text.

    Returns:
        Tuple ``(videos, txts)``: the element-0 tensors stacked along a new
        leading dimension, and a list of the element-1 values in batch order.
    """
    video_list = [entry[0] for entry in batch]
    batched_videos = torch.stack(video_list)
    text_list = [entry[1] for entry in batch]
    return batched_videos, text_list
class WebVidDataset(wids.ShardListDataset):
    """Video/caption pairs read from WebDataset shards through ``wids``.

    Each sample is decoded into ``(video, caption)``:
    - ``video`` is the ``.mp4`` entry decoded to a ``TCHW`` tensor.
    - It is cut to ``video_length + video_length_offset`` consecutive frames.
    - It is resized (aspect ratio preserved) to cover ``video_size`` and then
      cropped to exactly ``video_size``.
    - ``caption`` is the sample's ``.txt`` entry.

    Args:
        shards: Shard specification forwarded to ``wids.ShardListDataset``.
        cache_dir: Directory passed to the base class as ``cache_dir``
            (shards are kept, ``keep=True``).
        video_length: Base number of frames per clip.
        video_size: Output ``(height, width)``, or a single int for a square
            output. Every side must be divisible by 8.
        video_length_offset: Extra frames added to ``video_length`` per clip.
        val: If True, use a center crop and no flip, and restore the temporal
            sampler's seeded state before every sample (deterministic clips).
            Otherwise use a random crop and random horizontal flip.
        seed: Seed of the generator used to pick clip start frames.
        **kwargs: Forwarded to ``wids.ShardListDataset``.

    Raises:
        ValueError: If any side of ``video_size`` is not divisible by 8.
    """

    def __init__(self, shards, cache_dir, video_length=16, video_size=256, video_length_offset=1, val=False, seed=42,
                 **kwargs):
        if isinstance(video_size, int):
            video_size = (video_size, video_size)
        # Validate arguments before any base-class setup runs.
        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")
        self.video_size = video_size

        self.val = val
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)
        # Snapshot of the freshly seeded state; restored per sample in val mode.
        self.generator_init_state = self.generator.get_state()
        super().__init__(shards, cache_dir=cache_dir, keep=True, **kwargs)

        self.transform = transforms.Compose(
            [
                select_video_extract(length=video_length + video_length_offset, generator=self.generator),
                # Resizing straight to `video_size` made both crops below no-ops and stretched frames
                # whose aspect ratio differs; instead scale so the frames cover `video_size`, then crop.
                self._resize_to_cover,
                transforms.RandomCrop(size=video_size) if not self.val else transforms.CenterCrop(size=video_size),
                transforms.RandomHorizontalFlip() if not self.val else transforms.Identity(),
            ]
        )
        self.add_transform(self._make_sample)

    def _resize_to_cover(self, video):
        """Resize ``video`` (``..., H, W``) so H and W are at least ``self.video_size``, keeping aspect ratio."""
        target_h, target_w = self.video_size
        height, width = video.shape[-2:]
        scale = max(target_h / height, target_w / width)
        # max() guards against float rounding landing one pixel short of the crop size.
        new_size = [max(target_h, round(height * scale)), max(target_w, round(width * scale))]
        return transforms.functional.resize(video, new_size, antialias=True)

    def _make_sample(self, sample):
        """Decode one raw ``wids`` sample into ``(video, caption)``.

        Args:
            sample: Mapping with a ``.mp4`` entry exposing ``.read()`` and a
                ``.txt`` caption entry (presumably already decoded to ``str`` by
                the base class's decoder — confirm against the wids version used).

        Returns:
            Tuple of the transformed ``TCHW`` video tensor and the caption.
        """
        if self.val:
            # Rewind the temporal sampler so every validation clip is drawn from the same RNG state.
            self.generator.set_state(self.generator_init_state)
        # NOTE(review): each DataLoader worker presumably receives its own copy of this generator with
        # identical state, so in training mode all workers draw the same sequence of start offsets.
        # Consider reseeding per worker (e.g. via torch.utils.data.get_worker_info()).
        video = torchvision.io.read_video(BytesIO(sample[".mp4"].read()), output_format="TCHW", pts_unit='sec')[0]
        label = sample[".txt"]
        return self.transform(video), label
if __name__ == "__main__":
    # Smoke test: load a few batches and print their shapes and captions.
    # The previous call passed `tar_index`/`root_path`, which WebVidDataset does
    # not accept (and it omitted the required `shards`/`cache_dir`), so it raised
    # TypeError before loading anything.
    parser = argparse.ArgumentParser(description="Print a few WebVidDataset batches.")
    parser.add_argument(
        "--shards",
        default='/users/Etu9/3711799/onlyflow/data/webvid/desc.json',
        help="Shard specification passed to wids (presumably the JSON shard index "
             "the old `root_path` pointed at).",
    )
    parser.add_argument(
        "--cache-dir",
        default=None,
        help="Directory for the wids shard cache (a fresh temp dir if omitted).",
    )
    args = parser.parse_args()
    cache_dir = args.cache_dir or tempfile.mkdtemp(prefix="webvid_cache_")

    dataset = WebVidDataset(
        args.shards,
        cache_dir,
        video_length=16,
        video_size=256,
        video_length_offset=0,
    )
    sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)
    dataloader = DataLoader(
        dataset,
        collate_fn=my_collate_fn,
        batch_size=4,
        sampler=sampler,
        num_workers=4,
    )
    for i, (images, labels) in enumerate(dataloader):
        print(i, images.shape, labels)
        if i > 10:
            break