File size: 3,969 Bytes
c705408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from tracemalloc import start
import warnings
import glob
import random
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset
import torchvision
import torch.distributed as dist

from decord import VideoReader
from pcache_fileio import fileio
from pcache_fileio.oss_conf import OssConfigFactory


class SakugaRefDataset(Dataset):
    """Video dataset yielding one reference frame followed by a clip of frames.

    Each item is ``{'pixel_values': ...}`` built from a ``(f, c, h, w)`` frame
    tensor with ``f = video_frames + 1``: index 0 along the frame axis is the
    reference frame, indices ``1..video_frames`` are frames sampled (evenly
    spaced) from at least ``ref_jump_frames`` frames after the reference.
    """

    def __init__(
            self,
            video_frames=25,
            ref_jump_frames=36,
            base_folder='data/samples/',
            file_list=None,
            temporal_sample=None,
            transform=None,
            seed=42,
            max_retries=None,
        ):
        """
        Args:
            video_frames (int): Number of clip frames sampled after the reference frame.
            ref_jump_frames (int): Offset (in frames) from the reference frame to the
                first clip frame.
            base_folder (str): Directory holding the videos; entries of ``file_list``
                are joined onto it.
            file_list (str | None): Path to a text file with one relative video path per
                line (blank lines are skipped). When None, every ``*.mp4`` directly in
                ``base_folder`` is used (sorted) and decoded with torchvision; otherwise
                videos are opened with ``open`` and decoded with decord.
            temporal_sample (callable | None): Called with the video's total frame count;
                must return ``(ref_frame_ind, end_frame_ind)`` with ``end`` exclusive.
                None uses the whole video, i.e. ``(0, total_frames)``.
            transform (callable | None): Applied to the ``(f, c, h, w)`` frame tensor.
                None returns the frames unchanged.
            seed (int): Stored as ``self.seed``; not read anywhere in this class.
            max_retries (int | None): How many random replacement videos ``__getitem__``
                tries after a load failure before raising. None retries forever.
        """
        self.base_folder = base_folder

        self.file_list = file_list
        if file_list is None:
            # sorted(): raw glob order depends on the filesystem; sorting keeps the
            # index -> video mapping identical across processes.
            self.video_lists = sorted(glob.glob(os.path.join(self.base_folder, '*.mp4')))
        else:
            with open(file_list, 'r') as f:
                # Blank lines (e.g. a trailing newline) would otherwise become entries
                # pointing at base_folder itself, which can never be decoded.
                self.video_lists = [
                    os.path.join(self.base_folder, line.strip())
                    for line in f
                    if line.strip()
                ]

        self.num_samples = len(self.video_lists)
        self.channels = 3
        self.video_frames = video_frames
        self.ref_jump_frames = ref_jump_frames
        self.temporal_sample = temporal_sample
        self.transform = transform

        self.seed = seed
        self.max_retries = max_retries

    def __len__(self):
        return self.num_samples

    def _frame_indices(self, total_frames, path=None):
        """Return frame indices ``[ref, clip_0, ..., clip_{video_frames-1}]`` as an int array.

        Raises:
            ValueError: if the sampled window is shorter than
                ``video_frames + ref_jump_frames``.
        """
        if self.temporal_sample is None:
            ref_frame_ind, end_frame_ind = 0, total_frames
        else:
            ref_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - ref_frame_ind < self.video_frames + self.ref_jump_frames:
            raise ValueError(f'video {path} does not have enough frames')
        start_frame_ind = ref_frame_ind + self.ref_jump_frames
        # end_frame_ind is exclusive, hence the -1 for the last sampled frame.
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.video_frames, dtype=int)
        return np.insert(frame_indice, 0, ref_frame_ind)

    def get_sample(self, idx):
        """Load video ``idx`` and return its reference frame plus sampled clip.

        Args:
            idx (int): Index into ``self.video_lists``.

        Returns:
            dict: ``{'pixel_values': transform(frames)}`` where ``frames`` is a
            ``(video_frames + 1, c, h, w)`` tensor; ``[0]`` is the reference frame,
            the remaining entries are the clip frames.

        Raises:
            ValueError: if the video does not have enough frames; decoding and
                transform errors propagate unchanged.
        """
        path = self.video_lists[idx]

        if self.file_list is not None:  # read from pcache
            # NOTE(review): f is closed before get_batch() below; this relies on
            # decord's VideoReader consuming the file object at construction.
            with open(path, 'rb') as f:
                vframes = VideoReader(f)
        else:
            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')

        frame_indice = self._frame_indices(len(vframes), path)

        if self.file_list is not None:  # read from pcache
            # decord batches come back channels-last; permute to (f c h w) to match
            # the torchvision branch.
            video = torch.from_numpy(vframes.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            video = vframes[frame_indice]

        # (f c h w)
        pixel_values = video if self.transform is None else self.transform(video)

        return {'pixel_values': pixel_values}

    def __getitem__(self, idx):
        """Return ``get_sample(idx)``; on failure, retry with uniformly random other videos.

        Raises:
            RuntimeError: if ``max_retries`` is set and every replacement attempt also
                fails (chained from the last error).
        """
        attempts = 0
        while True:
            try:
                return self.get_sample(idx)
            except Exception as err:
                # Exception rather than a bare except so Ctrl-C / SystemExit still stop
                # the run instead of being retried forever.
                if self.max_retries is not None and attempts >= self.max_retries:
                    raise RuntimeError(f'loading {idx} failed after {attempts} retries') from err
                warnings.warn(f'loading {idx} failed ({err!r}), retrying with a random video...')
                attempts += 1
                # randint's upper bound is exclusive: pass len() so every video is eligible.
                idx = np.random.randint(0, len(self.video_lists))