|
import os |
|
from typing import Optional |
|
from pathlib import Path |
|
|
|
from func_timeout import func_timeout, FunctionTimedOut |
|
from PIL import Image |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
from .logger import logger |
|
from .video_utils import extract_frames |
|
|
|
|
|
# File extensions (without the leading dot) that are treated as video files.
ALL_VIDEO_EXT = {"mp4", "webm", "mkv", "avi", "flv", "mov"}

# Timeout passed to func_timeout when reading frames from a single video
# (func_timeout interprets this value as seconds).
VIDEO_READER_TIMEOUT = 300
|
|
|
|
|
def collate_fn(batch):
    """Collate a list of sample dicts into a dict of lists, skipping ``None`` samples.

    Args:
        batch: A list whose entries are either ``None`` or dicts. The keys of
            the first non-``None`` entry decide the keys of the result, so
            every remaining entry must provide those keys as well.

    Returns:
        A dict mapping each key to the list of per-sample values, in batch
        order. An empty dict is returned when no usable sample remains.
    """
    # Drop the samples that are None before collating.
    samples = [item for item in batch if item is not None]
    if not samples:
        return {}
    return {key: [item[key] for item in samples] for key in samples[0]}
|
|
|
|
|
class VideoDataset(Dataset):
    """Dataset that returns sampled frames for each listed video (or image) path.

    Each item is a dict with the keys ``"path"``, ``"sampled_frame_idx"`` and
    ``"sampled_frame"``, plus ``"text"`` when a text column was configured.
    Items that cannot be read are returned as ``None`` (after a warning is
    logged), so the consumer has to filter them out.
    """

    def __init__(
        self,
        dataset_inputs: dict[str, list[str]],
        video_folder: Optional[str] = None,
        video_path_column: str = "video_path",
        text_column: Optional[str] = None,
        sample_method: str = "mid",
        num_sampled_frames: int = 1,
        num_sample_stride: Optional[int] = None,
    ):
        """Store the sample paths, optional texts and frame sampling settings.

        Args:
            dataset_inputs: Column name -> list of values; all lists must have
                the same length.
            video_folder: Optional folder joined in front of every path.
            video_path_column: Key of ``dataset_inputs`` holding the paths.
            text_column: Optional key of ``dataset_inputs`` holding a text per
                sample; when set, each item carries it under ``"text"``.
            sample_method: ``"image"`` loads the path as a single RGB image;
                any other value is forwarded to ``extract_frames``.
            num_sampled_frames: Forwarded to ``extract_frames``.
            num_sample_stride: Forwarded to ``extract_frames``.

        Raises:
            ValueError: If the columns of ``dataset_inputs`` differ in length.
            KeyError: If ``video_path_column`` (or a given ``text_column``) is
                not a key of ``dataset_inputs``.
        """
        # Comparing the distinct lengths avoids indexing into the key list,
        # so an empty dict reaches the KeyError below instead of an IndexError.
        column_lengths = {len(column) for column in dataset_inputs.values()}
        if len(column_lengths) > 1:
            raise ValueError("All values in the dataset_inputs must have the same length.")

        self.video_path_column = video_path_column
        self.video_folder = video_folder
        self.video_path_list = dataset_inputs[video_path_column]
        if self.video_folder is not None:
            self.video_path_list = [
                os.path.join(self.video_folder, video_path) for video_path in self.video_path_list
            ]
        self.text_column = text_column
        self.text_list = dataset_inputs[self.text_column] if self.text_column is not None else None

        self.sample_method = sample_method
        self.num_sampled_frames = num_sampled_frames
        self.num_sample_stride = num_sample_stride

    def __getitem__(self, index):
        """Load the frames of the sample at ``index``.

        Returns:
            A dict with ``"path"``, ``"sampled_frame_idx"``, ``"sampled_frame"``
            and, if texts were configured, ``"text"``; or ``None`` when reading
            failed or timed out.
        """
        video_path = self.video_path_list[index]
        if self.sample_method == "image":
            # A still image has no frame indices.
            sampled_frame_idx_list = None
            try:
                with open(video_path, "rb") as f:
                    sampled_frame_list = [Image.open(f).convert("RGB")]
            except Exception as e:
                logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
                return None
        else:
            sample_args = (video_path, self.sample_method, self.num_sampled_frames, self.num_sample_stride)
            try:
                # Run the reader under a timeout so that one unreadable video
                # cannot block loading forever.
                sampled_frame_idx_list, sampled_frame_list = func_timeout(
                    VIDEO_READER_TIMEOUT, extract_frames, args=sample_args
                )
            except FunctionTimedOut:
                logger.warning(f"Read {video_path} timeout.")
                return None
            except Exception as e:
                logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
                return None

        item = {
            "path": video_path,
            "sampled_frame_idx": sampled_frame_idx_list,
            "sampled_frame": sampled_frame_list,
        }
        if self.text_list is not None:
            item["text"] = self.text_list[index]

        return item

    def __len__(self):
        """Return the number of listed paths."""
        return len(self.video_path_list)
|
|
|
|
|
if __name__ == "__main__":
    # Minimal smoke test: list the videos of a folder and iterate over them.
    video_folder = Path("your_video_folder")
    video_path_list = []
    for ext in ALL_VIDEO_EXT:
        video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.glob(f"*.{ext}")]

    # The listed paths are relative to video_folder, so the folder has to be
    # handed to the dataset for them to resolve to the actual files.
    video_dataset = VideoDataset(
        dataset_inputs={"video_path": video_path_list},
        video_folder=str(video_folder),
    )
    video_dataloader = DataLoader(
        video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn
    )
    for batch in video_dataloader:
        # collate_fn returns an empty dict when every sample of the batch failed.
        if batch:
            # Items expose the file path under "path" (not "video_path").
            print(batch["path"], batch["sampled_frame_idx"], len(batch["path"]))