# Spaces: Sleeping (page-banner residue from the web view; not part of the module)
import functools | |
import os | |
from io import BytesIO | |
import torch | |
import torchvision | |
import torchvision.transforms.v2 as transforms | |
import webdataset as wds | |
def _video_shortener(video_tensor, length): | |
start = torch.randint(0, video_tensor.shape[0] - length, (1,)) | |
return video_tensor[start:start + length] | |
def select_video_extract(length=16):
    """Build a one-argument callable that extracts a ``length``-frame clip.

    The returned object is ``_video_shortener`` with its ``length`` keyword
    pre-bound via ``functools.partial``, so it only needs the video tensor.
    """
    extractor = functools.partial(_video_shortener, length=length)
    return extractor
def my_collate_fn(batch):
    """Collate a list of sample dicts into a single dict of columns.

    The keys are taken from the first sample. Values stored under ``'video'``
    are combined with ``torch.stack`` (adding a leading batch dimension);
    every other key is gathered into a plain Python list in sample order.
    """
    collated = {}
    for name in batch[0]:
        column = [item[name] for item in batch]
        # Only the video entry is stacked into one tensor; the rest stay lists.
        collated[name] = torch.stack(column) if name == 'video' else column
    return collated
def map_mp4(sample):
    """Decode an in-memory video payload into its frame tensor.

    ``sample`` is wrapped in ``BytesIO`` and passed to
    ``torchvision.io.read_video`` with ``output_format="TCHW"``. Only the
    first element of the returned tuple (the video frames) is kept; the
    remaining elements are discarded.
    """
    stream = BytesIO(sample)
    decoded = torchvision.io.read_video(stream, output_format="TCHW", pts_unit='sec')
    return decoded[0]
def map_txt(sample):
    """Decode a UTF-8 encoded byte payload into a ``str``."""
    text = sample.decode("utf-8")
    return text
class WebVidDataset(wds.DataPipeline):
    """WebDataset pipeline that yields batches of video clips with captions.

    Each tar sample's ``mp4`` entry is decoded with ``map_mp4``, cut to a
    random clip, resized, cropped and optionally flipped; its ``txt`` entry is
    decoded with ``map_txt``. The entries are then renamed to ``video`` and
    ``text`` and grouped with ``my_collate_fn``.
    """

    def __init__(self, batch_size, tar_index, root_path, video_length=16, video_size=256, video_length_offset=0,
                 horizontal_flip=True, seed=None):
        """Assemble the shard list, decoding, augmentation and batching stages.

        Args:
            batch_size: Number of samples per batch handed to ``wds.batched``.
            tar_index: Text placed inside literal braces in the shard name,
                giving ``webvid-uw-{<tar_index>}.tar``. Presumably a
                brace-expansion pattern such as ``"0000..0009"`` that the
                shard list expands -- TODO confirm against callers.
            root_path: Directory joined in front of the shard name.
            video_length: Base number of frames per clip.
            video_size: Target spatial size, either an int (used for both
                dimensions) or a pair. Every value must be divisible by 8.
            video_length_offset: Extra frames added to ``video_length`` when
                extracting the clip.
            horizontal_flip: Whether to append ``RandomHorizontalFlip`` to the
                video transforms.
            seed: Seed forwarded to ``wds.SimpleShardList``.

        Raises:
            ValueError: If any entry of ``video_size`` is not divisible by 8.
        """
        self.dataset_full_path = os.path.join(root_path, f'webvid-uw-{{{tar_index}}}.tar')

        if isinstance(video_size, int):
            video_size = (video_size, video_size)
        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        video_transforms = [
            select_video_extract(length=video_length + video_length_offset),
            transforms.Resize(size=video_size),
            transforms.RandomCrop(size=video_size),
        ]
        # Add the flip only when requested. The previous fallback put the
        # un-instantiated `transforms.Identity` class into the list, which
        # Compose would then call as `Identity(video)` instead of passing the
        # video through.
        if horizontal_flip:
            video_transforms.append(transforms.RandomHorizontalFlip())

        self.pipeline = [
            wds.SimpleShardList('file:' + str(self.dataset_full_path), seed=seed),
            wds.shuffle(50),
            wds.split_by_node,
            wds.tarfile_to_samples(),
            wds.shuffle(100),
            wds.split_by_worker,
            # Decode the raw tar entries: mp4 -> frame tensor, txt -> str.
            wds.map_dict(
                mp4=map_mp4,
                txt=map_txt,
            ),
            # Clip extraction and spatial augmentation on the decoded video.
            wds.map_dict(
                mp4=transforms.Compose(video_transforms)
            ),
            wds.rename_keys(video="mp4", text='txt', keep_unselected=True),
            wds.batched(batch_size, collation_fn=my_collate_fn, partial=True)
        ]
        super().__init__(self.pipeline)

        self.batch_size = batch_size
        self.video_length = video_length
        self.video_size = video_size