import functools
from io import BytesIO

import torch
import torchvision
import torchvision.transforms.v2 as transforms
import wids
from torch.utils.data import DataLoader


def _video_shortener(video_tensor, length, generator=None):
    """Extract a random temporal window of `length` frames from a (T, C, H, W) video."""
    # Guard against clips that are exactly `length` frames long: torch.randint
    # requires high > low, so clamp the upper bound to at least 1 (start = 0).
    max_start = max(video_tensor.shape[0] - length, 1)
    start = torch.randint(0, max_start, (1,), generator=generator).item()
    return video_tensor[start:start + length]


def select_video_extract(length=16, generator=None):
    """Return a transform that samples a random clip of `length` frames."""
    return functools.partial(_video_shortener, length=length, generator=generator)


def my_collate_fn(batch):
    """Stack video tensors into one batch; keep captions as a list of strings."""
    videos = torch.stack([sample[0] for sample in batch])
    txts = [sample[1] for sample in batch]

    return videos, txts


class WebVidDataset(wids.ShardListDataset):
    """WebVid video/caption dataset backed by WebDataset shards via `wids`."""

    def __init__(self, shards, cache_dir, video_length=16, video_size=256, video_length_offset=1, val=False, seed=42,
                 **kwargs):

        self.val = val
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)
        self.generator_init_state = self.generator.get_state()
        super().__init__(shards, cache_dir=cache_dir, keep=True, **kwargs)

        if isinstance(video_size, int):
            video_size = (video_size, video_size)

        self.video_size = video_size

        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        self.transform = transforms.Compose(
            [
                # Sample a clip `video_length_offset` frames longer than the target length.
                select_video_extract(length=video_length + video_length_offset, generator=self.generator),
                # Resize the smaller edge only; resizing directly to `video_size`
                # would make the subsequent crop a no-op.
                transforms.Resize(size=min(video_size)),
                transforms.RandomCrop(size=video_size) if not self.val else transforms.CenterCrop(size=video_size),
                transforms.RandomHorizontalFlip() if not self.val else transforms.Identity(),
            ]
        )

        self.add_transform(self._make_sample)

    def _make_sample(self, sample):
        # In validation mode, reset the RNG before each sample so clip selection
        # (and any random transforms) are deterministic across epochs.
        if self.val:
            self.generator.set_state(self.generator_init_state)
        video = torchvision.io.read_video(BytesIO(sample[".mp4"].read()), output_format="TCHW", pts_unit="sec")[0]
        label = sample[".txt"]
        return self.transform(video), label


if __name__ == "__main__":

    dataset = WebVidDataset(
        # `__init__` expects `shards` and `cache_dir`; the original call passed
        # unsupported `tar_index`/`root_path` keywords. `cache_dir` below is a
        # placeholder -- point it at a writable directory.
        shards='/users/Etu9/3711799/onlyflow/data/webvid/desc.json',
        cache_dir='./cache',
        video_length=16,
        video_size=256,
        video_length_offset=0,
    )

    # Chunked shuffling keeps shard reads mostly sequential while still mixing samples.
    sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)
    dataloader = DataLoader(
        dataset,
        collate_fn=my_collate_fn,
        batch_size=4,
        sampler=sampler,
        num_workers=4
    )

    for i, (images, labels) in enumerate(dataloader):
        print(i, images.shape, labels)
        if i > 10:
            break