import os
import warnings
import glob

import numpy as np

import torch
from torch.utils.data import Dataset
import torchvision

from decord import VideoReader
# Optional remote-storage (pcache/OSS) support, used when videos are read from a file list.
from pcache_fileio import fileio
from pcache_fileio.oss_conf import OssConfigFactory


class SakugaRefDataset(Dataset):

    def __init__(
        self,
        video_frames=25,
        ref_jump_frames=36,
        base_folder='data/samples/',
        file_list=None,
        temporal_sample=None,
        transform=None,
        seed=42,
    ):
        """
        Args:
            video_frames (int): Number of target frames sampled per clip.
            ref_jump_frames (int): Minimum gap (in frames) between the reference frame
                and the first target frame.
            base_folder (str): Root directory containing the videos.
            file_list (str, optional): Text file listing video paths relative to
                base_folder. If None, all '*.mp4' files under base_folder are used.
            temporal_sample (callable, optional): Maps a total frame count to a
                (start, end) frame-index pair.
            transform (callable, optional): Transform applied to the sampled
                (T, C, H, W) frame tensor.
            seed (int): Random seed.
        """
        self.base_folder = base_folder

        self.file_list = file_list
        if file_list is None:
            self.video_lists = glob.glob(os.path.join(self.base_folder, '*.mp4'))
        else:
            self.video_lists = []
            with open(file_list, 'r') as f:
                for line in f:
                    video_path = line.strip()
                    self.video_lists.append(os.path.join(self.base_folder, video_path))

        self.num_samples = len(self.video_lists)
        self.channels = 3

        self.video_frames = video_frames
        self.ref_jump_frames = ref_jump_frames
        self.temporal_sample = temporal_sample
        self.transform = transform

        self.seed = seed

    def __len__(self):
        return self.num_samples

    def get_sample(self, idx):
        """
        Args:
            idx (int): Index of the sample to load.

        Returns:
            dict: A dictionary with a 'pixel_values' tensor of shape
                (video_frames + 1, channels, H, W); the first frame is the reference frame.
        """
        path = self.video_lists[idx]

        if self.file_list is not None:
            # Paths from a file list may live on remote (pcache/OSS) storage; decode with decord.
            with open(path, 'rb') as f:
                vframes = VideoReader(f)
        else:
            vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Pick a temporal window, place the reference frame at its start, and sample
        # the target frames after a jump of at least ref_jump_frames.
        ref_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - ref_frame_ind < self.video_frames + self.ref_jump_frames:
            raise ValueError(f'video {path} does not have enough frames')
        start_frame_ind = ref_frame_ind + self.ref_jump_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.video_frames, dtype=int)
        frame_indice = np.insert(frame_indice, 0, ref_frame_ind)

        if self.file_list is not None:
            # decord returns frames as THWC uint8; convert to TCHW to match torchvision's output.
            video = torch.from_numpy(vframes.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            video = vframes[frame_indice]

        pixel_values = self.transform(video)

        return {'pixel_values': pixel_values}

    def __getitem__(self, idx):
        # Retry with a different random index if a video fails to load or is too short.
        while True:
            try:
                item = self.get_sample(idx)
                return item
            except Exception as e:
                warnings.warn(f'failed to load sample {idx} ({e}); resampling another index')
                idx = np.random.randint(0, len(self.video_lists))
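

# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal example of how this dataset might be constructed. The temporal
# sampler and transform below are assumptions, not part of this module:
# temporal_sample can be any callable mapping a total frame count to a
# (start, end) index pair, and transform any callable over a (T, C, H, W)
# uint8 tensor.
if __name__ == '__main__':

    class TemporalRandomCrop:
        """Sample a random window of `size` frames; returns (begin, end) indices."""

        def __init__(self, size):
            self.size = size

        def __call__(self, total_frames):
            rand_end = max(0, total_frames - self.size - 1)
            begin = np.random.randint(0, rand_end + 1)
            end = min(begin + self.size, total_frames)
            return begin, end

    def simple_transform(video):
        # video: (T, C, H, W) uint8 -> float32 in [-1, 1]
        return video.float() / 127.5 - 1.0

    dataset = SakugaRefDataset(
        video_frames=25,
        ref_jump_frames=36,
        base_folder='data/samples/',
        # The window must span at least video_frames + ref_jump_frames frames.
        temporal_sample=TemporalRandomCrop(25 + 36 + 1),
        transform=simple_transform,
    )
    if len(dataset) > 0:
        sample = dataset[0]
        # Expected shape: (26, 3, H, W) — the reference frame plus 25 target frames.
        print(sample['pixel_values'].shape)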