File size: 2,530 Bytes
9bb001a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import functools
import os
from io import BytesIO

import torch
import torchvision
import torchvision.transforms.v2 as transforms
import webdataset as wds


def _video_shortener(video_tensor, length):
    """Return a random contiguous clip of ``length`` frames from a video.

    Args:
        video_tensor: Tensor whose first dimension indexes frames.
        length: Number of consecutive frames to keep.

    Returns:
        ``video_tensor[start:start + length]`` where ``start`` is drawn
        uniformly from every position at which a full clip still fits.

    Raises:
        ValueError: If the video has fewer than ``length`` frames.
    """
    num_frames = video_tensor.shape[0]
    if num_frames < length:
        raise ValueError(
            f"video has {num_frames} frames but a clip of {length} frames was requested"
        )

    # torch.randint excludes its upper bound, so add 1 to make the last valid
    # start (num_frames - length) reachable; this also keeps the range
    # non-empty when the video is exactly `length` frames long.
    start = int(torch.randint(0, num_frames - length + 1, (1,)))
    return video_tensor[start:start + length]


def select_video_extract(length=16):
    """Build a single-argument callable that crops a video to a random clip.

    Args:
        length: Number of consecutive frames each returned clip contains.

    Returns:
        A ``functools.partial`` of ``_video_shortener`` with ``length`` bound,
        suitable for use as a stage in a transform pipeline.
    """
    return functools.partial(_video_shortener, length=length)


def my_collate_fn(batch):
    """Collate a list of sample dicts into one batch dict.

    The keys of the first sample define the keys of the result. Values under
    ``'video'`` are combined with ``torch.stack``; values under every other
    key are collected into a plain list in sample order.

    Args:
        batch: Non-empty sequence of dicts that all contain the keys of the
            first dict.

    Returns:
        A dict mapping each key to its stacked tensor or list of values.
    """
    def _gather(field):
        values = [item[field] for item in batch]
        return torch.stack(values) if field == 'video' else values

    return {field: _gather(field) for field in batch[0]}


def map_mp4(sample):
    """Decode an in-memory MP4 payload and return only its video frames.

    Args:
        sample: Raw bytes of the encoded video.

    Returns:
        The first element of ``torchvision.io.read_video``'s result, with
        frames requested in ``"TCHW"`` layout.
    """
    # Wrap the bytes in a file-like object so the decoder can read from memory.
    stream = BytesIO(sample)
    decoded = torchvision.io.read_video(stream, output_format="TCHW", pts_unit='sec')
    # The remaining elements of the result are not needed here.
    return decoded[0]


def map_txt(sample):
    """Decode a UTF-8 encoded byte payload into text.

    Args:
        sample: Bytes-like object exposing ``decode``.

    Returns:
        The decoded ``str``.
    """
    caption = sample.decode(encoding="utf-8")
    return caption


class WebVidDataset(wds.DataPipeline):
    """WebDataset pipeline producing batches of video clips with their captions.

    Each sample's ``mp4`` entry is decoded, cut to a random run of consecutive
    frames, resized, cropped and optionally flipped; its ``txt`` entry is
    decoded to a string. The entries are then renamed to ``video`` and
    ``text`` and batched with ``my_collate_fn``, so every batch is a dict
    whose ``'video'`` value is a stacked tensor and whose other values are
    lists.
    """

    def __init__(self, batch_size, tar_index, root_path, video_length=16, video_size=256, video_length_offset=0,
                 horizontal_flip=True, seed=None):
        """Assemble the pipeline stages.

        Args:
            batch_size: Number of samples per batch (the final batch may be
                smaller because ``partial=True``).
            tar_index: Text placed between literal braces in the shard file
                name ``webvid-uw-{<tar_index>}.tar``.
            root_path: Directory containing the tar shards.
            video_length: Base number of frames per clip.
            video_length_offset: Extra frames added to ``video_length`` when
                extracting a clip.
            video_size: Target spatial size; an int is expanded to a square
                ``(size, size)``. Every dimension must be divisible by 8.
            horizontal_flip: Whether to apply ``RandomHorizontalFlip``.
            seed: Seed forwarded to ``wds.SimpleShardList``.

        Raises:
            ValueError: If any entry of ``video_size`` is not divisible by 8.
        """
        # The doubled braces in the f-string emit literal '{' and '}' around
        # tar_index, producing a brace pattern such as
        # 'webvid-uw-{0000..0009}.tar' (presumably expanded into individual
        # shard names by webdataset -- see its shard-list documentation).
        self.dataset_full_path = os.path.join(root_path, f'webvid-uw-{{{tar_index}}}.tar')

        if isinstance(video_size, int):
            video_size = (video_size, video_size)

        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        # Per-clip transforms. The flip is appended only when requested rather
        # than substituting a placeholder: the previous placeholder was the
        # un-instantiated `transforms.Identity` class, which Compose would
        # have called with the frames as constructor arguments.
        clip_transforms = [
            select_video_extract(length=video_length + video_length_offset),
            transforms.Resize(size=video_size),
            transforms.RandomCrop(size=video_size),
        ]
        if horizontal_flip:
            clip_transforms.append(transforms.RandomHorizontalFlip())

        stages = [
            wds.SimpleShardList('file:' + str(self.dataset_full_path), seed=seed),
            wds.shuffle(50),
            wds.split_by_node,
            wds.tarfile_to_samples(),
            wds.shuffle(100),
            wds.split_by_worker,
            # Decode the raw bytes of each entry.
            wds.map_dict(
                mp4=map_mp4,
                txt=map_txt,
            ),
            # Apply clip extraction and augmentation to the decoded video only.
            wds.map_dict(
                mp4=transforms.Compose(clip_transforms),
            ),
            wds.rename_keys(video="mp4", text='txt', keep_unselected=True),
            wds.batched(batch_size, collation_fn=my_collate_fn, partial=True),
        ]

        super().__init__(stages)

        self.batch_size = batch_size
        self.video_length = video_length
        self.video_size = video_size