Spaces:
Sleeping
Sleeping
File size: 2,530 Bytes
9bb001a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import functools
import os
from io import BytesIO
import torch
import torchvision
import torchvision.transforms.v2 as transforms
import webdataset as wds
def _video_shortener(video_tensor, length):
start = torch.randint(0, video_tensor.shape[0] - length, (1,))
return video_tensor[start:start + length]
def select_video_extract(length=16):
    """Build a single-argument callable that cuts a random extract from a video.

    Args:
        length: Number of entries along dim 0 that the returned callable keeps.

    Returns:
        A ``functools.partial`` of ``_video_shortener`` with ``length`` bound
        as a keyword argument, so it can be called with the video tensor only.
    """
    extractor = functools.partial(_video_shortener, length=length)
    return extractor
def my_collate_fn(batch):
    """Collate a list of sample dicts into a single dict of batched fields.

    The keys are taken from the first sample. The ``'video'`` field is stacked
    into one tensor with ``torch.stack``; every other field is gathered into a
    plain list, one entry per sample.

    Args:
        batch: Non-empty sequence of dicts that share the same keys.

    Returns:
        Dict mapping each key to its batched value.
    """
    collated = {}
    for field in batch[0]:
        values = [item[field] for item in batch]
        # Only videos are tensors of matching shape; other fields stay as lists.
        collated[field] = torch.stack(values) if field == 'video' else values
    return collated
def map_mp4(sample):
    """Decode an in-memory MP4 payload and return its video frames.

    Args:
        sample: Raw bytes of the video file.

    Returns:
        The first element of ``torchvision.io.read_video``'s result (the video
        frames), requested in ``"TCHW"`` layout with timestamps in seconds.
    """
    buffer = BytesIO(sample)
    decoded = torchvision.io.read_video(buffer, output_format="TCHW", pts_unit='sec')
    # read_video returns several items; only the leading frame tensor is used.
    return decoded[0]
def map_txt(sample):
    """Decode a UTF-8 encoded byte payload into a string.

    Args:
        sample: Bytes-like object exposing ``decode``.

    Returns:
        The decoded text.
    """
    text = sample.decode(encoding="utf-8")
    return text
class WebVidDataset(wds.DataPipeline):
    """WebDataset pipeline yielding batches of video extracts with their captions.

    Shards are read from ``<root_path>/webvid-uw-{<tar_index>}.tar``. Each
    sample's ``mp4`` entry is decoded with ``map_mp4`` and its ``txt`` entry
    with ``map_txt``. The video is then cut to a random extract, resized,
    cropped and optionally flipped. Finally ``mp4``/``txt`` are renamed to
    ``video``/``text`` and samples are batched with ``my_collate_fn``.

    Args:
        batch_size: Number of samples per batch (``partial=True`` is passed to
            ``wds.batched``).
        tar_index: Shard index expression placed between braces in the shard
            file name (e.g. a brace-expansion range -- confirm with callers).
        root_path: Directory containing the ``webvid-uw-*.tar`` shards.
        video_length: Base number of frames kept per video.
        video_size: Target spatial size, either an int (used for both
            dimensions) or an iterable of sizes; each must be divisible by 8.
        video_length_offset: Extra frames added to ``video_length`` when
            cutting the extract.
        horizontal_flip: Whether to apply ``RandomHorizontalFlip``.
        seed: Seed forwarded to ``wds.SimpleShardList``.

    Raises:
        ValueError: If any entry of ``video_size`` is not divisible by 8.
    """

    def __init__(self, batch_size, tar_index, root_path, video_length=16, video_size=256, video_length_offset=0,
                 horizontal_flip=True, seed=None):
        # The doubled braces emit literal "{...}" around tar_index in the shard name.
        self.dataset_full_path = os.path.join(root_path, f'webvid-uw-{{{tar_index}}}.tar')

        if isinstance(video_size, int):
            video_size = (video_size, video_size)
        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        video_transforms = [
            select_video_extract(length=video_length + video_length_offset),
            transforms.Resize(size=video_size),
            transforms.RandomCrop(size=video_size),
        ]
        # Append the flip only when requested. The previous fallback put the
        # un-instantiated ``transforms.Identity`` class into the Compose list,
        # which would be called on the video instead of passing it through.
        if horizontal_flip:
            video_transforms.append(transforms.RandomHorizontalFlip())

        self.pipeline = [
            wds.SimpleShardList('file:' + str(self.dataset_full_path), seed=seed),
            wds.shuffle(50),
            wds.split_by_node,
            wds.tarfile_to_samples(),
            wds.shuffle(100),
            wds.split_by_worker,
            # Decode the raw bytes of each entry.
            wds.map_dict(
                mp4=map_mp4,
                txt=map_txt,
            ),
            # Temporal extract and spatial augmentation of the decoded video.
            wds.map_dict(
                mp4=transforms.Compose(video_transforms)
            ),
            wds.rename_keys(video="mp4", text='txt', keep_unselected=True),
            wds.batched(batch_size, collation_fn=my_collate_fn, partial=True),
        ]
        super().__init__(self.pipeline)

        self.batch_size = batch_size
        self.video_length = video_length
        self.video_size = video_size
|