# OnlyFlow: onlyflow/data/dataset_itr.py
# Author: arlaz
# Initial commit (9bb001a)
import functools
import os
from io import BytesIO
import torch
import torchvision
import torchvision.transforms.v2 as transforms
import webdataset as wds
def _video_shortener(video_tensor, length):
start = torch.randint(0, video_tensor.shape[0] - length, (1,))
return video_tensor[start:start + length]
def select_video_extract(length=16):
    """Create a callable that crops a random ``length``-frame window from a video.

    The result is ``_video_shortener`` with ``length`` bound as a keyword
    argument, so it can be used as a step of a ``transforms.Compose`` pipeline.
    """
    bound_kwargs = {"length": length}
    return functools.partial(_video_shortener, **bound_kwargs)
def my_collate_fn(batch):
    """Collate a list of sample dicts into one dict of per-key batches.

    The keys are taken from the first sample. Entries under ``'video'`` are
    combined with ``torch.stack`` (new leading batch dimension, so all videos
    must share a shape). Every other key is gathered into a plain list in
    sample order.
    """
    def _gather(key):
        values = [sample[key] for sample in batch]
        return torch.stack(values) if key == 'video' else values

    return {key: _gather(key) for key in batch[0].keys()}
def map_mp4(sample):
    """Decode raw MP4 bytes into a video frame tensor in ``TCHW`` layout.

    The bytes are wrapped in an in-memory buffer and decoded with
    ``torchvision.io.read_video`` (timestamps in seconds). Only the first
    element of its result, the video frames, is returned.
    """
    buffer = BytesIO(sample)
    decoded = torchvision.io.read_video(buffer, output_format="TCHW", pts_unit='sec')
    video_frames = decoded[0]
    return video_frames
def map_txt(sample):
    """Decode a caption payload from UTF-8 bytes into a ``str``."""
    text = sample.decode(encoding="utf-8")
    return text
class WebVidDataset(wds.DataPipeline):
    """WebDataset pipeline yielding batches of video clips and captions.

    The pipeline:
      - reads ``webvid-uw-{<tar_index>}.tar`` shards under ``root_path``;
      - shuffles at shard level, then again at sample level;
      - splits the work across nodes and dataloader workers;
      - decodes each sample's ``mp4`` and ``txt`` entries;
      - extracts a random clip, then resizes, crops and optionally flips it;
      - collates samples with ``my_collate_fn``.

    Each yielded batch is a dict with a stacked ``'video'`` tensor, a
    ``'text'`` list, and any other sample keys kept as lists.
    """

    def __init__(self, batch_size, tar_index, root_path, video_length=16, video_size=256, video_length_offset=0,
                 horizontal_flip=True, seed=None):
        """Build the pipeline.

        Args:
            batch_size: Samples per batch. The final batch may be smaller
                (``partial=True``).
            tar_index: Text placed inside ``{...}`` in the shard file name,
                which webdataset brace-expands (e.g. ``"000..099"``;
                TODO confirm the expected format against the shard layout).
            root_path: Directory containing the shards.
            video_length: Clip length in frames (also stored as ``self.video_length``).
            video_size: Output spatial size, either an int (square) or an
                ``(h, w)`` pair. Each side must be divisible by 8.
            video_length_offset: Extra frames extracted on top of ``video_length``.
            horizontal_flip: If True, apply a random horizontal flip.
            seed: Forwarded to ``wds.SimpleShardList``.

        Raises:
            ValueError: If any side of ``video_size`` is not divisible by 8.
        """
        self.dataset_full_path = os.path.join(root_path, f'webvid-uw-{{{tar_index}}}.tar')

        if isinstance(video_size, int):
            video_size = (video_size, video_size)
        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        # Identity must be *instantiated*: passing the class itself would make
        # Compose call ``Identity(video)`` as a constructor, which raises TypeError.
        flip = transforms.RandomHorizontalFlip() if horizontal_flip else transforms.Identity()

        video_transform = transforms.Compose(
            [
                select_video_extract(length=video_length + video_length_offset),
                # NOTE(review): with an (h, w) tuple, Resize scales to exactly that
                # size (aspect ratio not preserved), so the RandomCrop below is a
                # no-op. If shorter-side resizing was intended, use
                # Resize(size=min(video_size)) instead. Confirm before changing.
                transforms.Resize(size=video_size),
                transforms.RandomCrop(size=video_size),
                flip,
            ]
        )

        self.pipeline = [
            wds.SimpleShardList('file:' + str(self.dataset_full_path), seed=seed),
            wds.shuffle(50),  # shuffle shard order
            wds.split_by_node,
            wds.tarfile_to_samples(),
            wds.shuffle(100),  # shuffle samples within a buffer
            wds.split_by_worker,
            wds.map_dict(
                mp4=map_mp4,
                txt=map_txt,
            ),
            wds.map_dict(mp4=video_transform),
            wds.rename_keys(video="mp4", text='txt', keep_unselected=True),
            wds.batched(batch_size, collation_fn=my_collate_fn, partial=True),
        ]
        super().__init__(self.pipeline)
        self.batch_size = batch_size
        self.video_length = video_length
        self.video_size = video_size