Spaces:
Running
Running
from pathlib import Path | |
import pandas as pd | |
from func_timeout import FunctionTimedOut, func_timeout | |
from torch.utils.data import DataLoader, Dataset | |
from utils.logger import logger | |
from utils.video_utils import get_video_path_list, extract_frames | |
# Video file extensions (lower-case, no leading dot).
# NOTE(review): not referenced in this part of the file; presumably used by
# callers or other modules to filter video files — confirm before changing.
ALL_VIDEO_EXT = {"mp4", "webm", "mkv", "avi", "flv", "mov"}

# Time budget, in seconds, given to one ``extract_frames`` call via
# ``func_timeout`` in ``VideoDataset.__getitem__``.
VIDEO_READER_TIMEOUT = 10
def collate_fn(batch):
    """Merge a list of per-sample dicts into one dict of lists.

    ``None`` entries (samples that ``VideoDataset.__getitem__`` could not read)
    are discarded first. The keys of the first surviving sample become the
    output keys, and each maps to that key's values across all surviving
    samples, in their original order.

    Returns an empty dict when nothing survives the filter.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return {}
    return {key: [sample[key] for sample in kept] for key in kept[0].keys()}
class VideoDataset(Dataset):
    """Map-style dataset that samples frames from video files on demand.

    Each item runs ``extract_frames`` on one video path, bounded by a
    ``VIDEO_READER_TIMEOUT``-second ``func_timeout`` guard. Videos that time
    out or raise are logged and returned as ``None``; ``collate_fn`` in this
    module drops such ``None`` samples when batching.

    Args:
        video_path_list: Explicit sequence of video paths. When given it takes
            precedence and ``get_video_path_list`` is not called.
        video_folder: Passed to ``get_video_path_list`` when
            ``video_path_list`` is ``None``; also stored as ``self.video_folder``.
        video_metadata_path: Passed to ``get_video_path_list`` when
            ``video_path_list`` is ``None``.
        video_path_column: Passed to ``get_video_path_list`` and used as the
            single column label of ``self.metadata_df``.
        sample_method: Forwarded to ``extract_frames`` (default ``"mid"``).
        num_sampled_frames: Forwarded to ``extract_frames`` (default ``1``).
        num_sample_stride: Forwarded to ``extract_frames`` (default ``None``).
    """

    def __init__(
        self,
        video_path_list=None,
        video_folder=None,
        video_metadata_path=None,
        video_path_column=None,
        sample_method="mid",
        num_sampled_frames=1,
        num_sample_stride=None,
    ):
        self.video_path_column = video_path_column
        self.video_folder = video_folder
        self.sample_method = sample_method
        self.num_sampled_frames = num_sampled_frames
        self.num_sample_stride = num_sample_stride

        if video_path_list is not None:
            self.video_path_list = video_path_list
        else:
            self.video_path_list = get_video_path_list(
                video_folder=video_folder,
                video_metadata_path=video_metadata_path,
                video_path_column=video_path_column,
            )
        # Build ``metadata_df`` for both construction paths. Previously it was
        # only set when ``video_path_list`` was passed, so datasets built from
        # a folder / metadata file raised AttributeError on access.
        self.metadata_df = pd.DataFrame({video_path_column: self.video_path_list})

    def __getitem__(self, index):
        """Return sampled frames for the video at ``index``, or ``None`` on failure.

        On success the dict has keys ``"video_path"`` (file name only, the
        directory part stripped), ``"sampled_frame_idx"`` and
        ``"sampled_frame"`` (the two values returned by ``extract_frames``).
        """
        # NOTE(review): ``self.video_folder`` is not joined onto the path here;
        # paths are used exactly as stored in ``self.video_path_list``.
        video_path = self.video_path_list[index]
        sample_args = (
            video_path,
            self.sample_method,
            self.num_sampled_frames,
            self.num_sample_stride,
        )
        try:
            # Bound decode time so one slow or corrupt file cannot stall loading.
            sampled_frame_idx_list, sampled_frame_list = func_timeout(
                VIDEO_READER_TIMEOUT, extract_frames, args=sample_args
            )
        except FunctionTimedOut:
            logger.warning(f"Read {video_path} timeout.")
            return None
        except Exception as e:
            # Best-effort: skip unreadable videos instead of raising out of
            # the DataLoader; ``collate_fn`` filters the ``None`` out.
            logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
            return None

        return {
            "video_path": Path(video_path).name,
            "sampled_frame_idx": sampled_frame_idx_list,
            "sampled_frame": sampled_frame_list,
        }

    def __len__(self):
        """Return the number of video paths in the dataset."""
        return len(self.video_path_list)
if __name__ == "__main__":
    # Demo: read every video under a folder and print what was sampled.
    # Replace the placeholder folder name with a real path before running.
    video_folder = "your_video_folder"
    video_dataset = VideoDataset(video_folder=video_folder)
    video_dataloader = DataLoader(
        video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn
    )
    for batch in video_dataloader:
        # ``collate_fn`` yields {} when every video in the batch failed to load.
        if batch:
            print(batch["video_path"], batch["sampled_frame_idx"], len(batch["video_path"]))