"""Dataset utilities for sampling frames from a collection of video files."""

from pathlib import Path

import pandas as pd
from func_timeout import FunctionTimedOut, func_timeout
from torch.utils.data import DataLoader, Dataset

from utils.logger import logger
from utils.video_utils import get_video_path_list, extract_frames

# File extensions (without the leading dot) treated as video files.
ALL_VIDEO_EXT = {"mp4", "webm", "mkv", "avi", "flv", "mov"}
# Timeout passed to ``func_timeout`` for a single ``extract_frames`` call.
VIDEO_READER_TIMEOUT = 10


def collate_fn(batch):
    """Collate a list of samples into a dict of lists, dropping failed samples.

    Args:
        batch: Samples produced by ``VideoDataset.__getitem__``. Entries that
            are ``None`` (videos that could not be read) are discarded.

    Returns:
        A dict mapping every key of the first valid sample to the list of the
        corresponding values of all valid samples, or an empty dict when no
        valid sample is left.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return {}
    return {key: [item[key] for item in samples] for key in samples[0]}


class VideoDataset(Dataset):
    """Map-style dataset that yields frames sampled from each video.

    The list of videos is either given directly through ``video_path_list`` or
    obtained from ``get_video_path_list`` using ``video_folder``,
    ``video_metadata_path`` and ``video_path_column``.
    """

    def __init__(
        self,
        video_path_list=None,
        video_folder=None,
        video_metadata_path=None,
        video_path_column=None,
        sample_method="mid",
        num_sampled_frames=1,
        num_sample_stride=None,
    ):
        """Store the sampling configuration and resolve the video list.

        Args:
            video_path_list: Explicit list of video paths. When given, it takes
                precedence over the folder / metadata arguments.
            video_folder: Forwarded to ``get_video_path_list`` when
                ``video_path_list`` is ``None``.
            video_metadata_path: Forwarded to ``get_video_path_list`` when
                ``video_path_list`` is ``None``.
            video_path_column: Column name used for the video paths, both in
                ``metadata_df`` and in the call to ``get_video_path_list``.
            sample_method: Forwarded unchanged to ``extract_frames``.
            num_sampled_frames: Forwarded unchanged to ``extract_frames``.
            num_sample_stride: Forwarded unchanged to ``extract_frames``.
        """
        self.video_path_column = video_path_column
        self.video_folder = video_folder
        self.sample_method = sample_method
        self.num_sampled_frames = num_sampled_frames
        self.num_sample_stride = num_sample_stride

        if video_path_list is not None:
            self.video_path_list = video_path_list
            # NOTE(review): ``metadata_df`` is only assigned in this branch;
            # it does not exist when the list comes from get_video_path_list.
            self.metadata_df = pd.DataFrame({video_path_column: self.video_path_list})
        else:
            self.video_path_list = get_video_path_list(
                video_folder=video_folder,
                video_metadata_path=video_metadata_path,
                video_path_column=video_path_column,
            )

    def __getitem__(self, index):
        """Sample frames from the video at ``index``.

        Returns:
            A dict with the keys ``"video_path"`` (file name only, without the
            parent directories), ``"sampled_frame_idx"`` and
            ``"sampled_frame"`` (the two values returned by
            ``extract_frames``), or ``None`` when reading timed out or raised.
            ``collate_fn`` filters such ``None`` samples out.
        """
        video_path = self.video_path_list[index]
        sample_args = (
            video_path,
            self.sample_method,
            self.num_sampled_frames,
            self.num_sample_stride,
        )
        try:
            # Bound the call so that one unreadable video cannot block loading.
            sampled_frame_idx_list, sampled_frame_list = func_timeout(
                VIDEO_READER_TIMEOUT, extract_frames, args=sample_args
            )
        except FunctionTimedOut:
            logger.warning(f"Read {video_path} timeout.")
            return None
        except Exception as e:
            # Deliberate best-effort: a broken video is logged and skipped.
            logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
            return None

        return {
            "video_path": Path(video_path).name,
            "sampled_frame_idx": sampled_frame_idx_list,
            "sampled_frame": sampled_frame_list,
        }

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_path_list)


if __name__ == "__main__":
    # Minimal usage example: iterate over a folder and print what was sampled.
    video_folder = "your_video_folder"
    video_dataset = VideoDataset(video_folder=video_folder)
    video_dataloader = DataLoader(
        video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn
    )
    for batch in video_dataloader:
        # An empty dict means every sample of the batch failed to load.
        if batch:
            print(batch["video_path"], batch["sampled_frame_idx"], len(batch["video_path"]))