meepmoo's picture
Upload folder using huggingface_hub
0dcccdd verified
raw
history blame
3.89 kB
import os
from typing import Optional
from pathlib import Path
from func_timeout import func_timeout, FunctionTimedOut
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from .logger import logger
from .video_utils import extract_frames
# File extensions (without the leading dot) that are treated as video files.
ALL_VIDEO_EXT = {"mp4", "webm", "mkv", "avi", "flv", "mov"}

# Timeout handed to func_timeout when reading a video (seconds, per func_timeout's API).
VIDEO_READER_TIMEOUT = 300
def collate_fn(batch):
    """Merge per-sample dicts into one dict of lists, skipping ``None`` samples.

    Args:
        batch: Iterable of sample dicts; entries that are ``None`` are dropped.

    Returns:
        A dict mapping each key of the first remaining sample to the list of
        that key's values across all remaining samples, or ``{}`` when no
        sample remains.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return {}
    # The first surviving sample decides which keys are collected.
    return {key: [sample[key] for sample in samples] for key in samples[0]}
class VideoDataset(Dataset):
    """Dataset that yields sampled frames for each listed video (or image) path.

    Every item is a dict with the keys ``"path"``, ``"sampled_frame_idx"`` and
    ``"sampled_frame"``, plus ``"text"`` when ``text_column`` is given. An item
    that cannot be read (error or timeout) is returned as ``None``.

    Args:
        dataset_inputs: Mapping of column name to a list of values. All
            columns must have the same length.
        video_folder: Optional directory joined in front of every path.
        video_path_column: Column of ``dataset_inputs`` that holds the paths.
        text_column: Optional column whose value is attached to each item
            under ``"text"``.
        sample_method: ``"image"`` opens the path as a single image with PIL;
            any other value is forwarded to ``extract_frames``.
        num_sampled_frames: Forwarded to ``extract_frames``.
        num_sample_stride: Forwarded to ``extract_frames``.

    Raises:
        ValueError: If the columns of ``dataset_inputs`` differ in length.
    """

    def __init__(
        self,
        dataset_inputs: dict[str, list[str]],
        video_folder: Optional[str] = None,
        video_path_column: str = "video_path",
        text_column: Optional[str] = None,
        sample_method: str = "mid",
        num_sampled_frames: int = 1,
        num_sample_stride: Optional[int] = None
    ):
        # Use the first column as the reference length for all others.
        columns = list(dataset_inputs.values())
        expected_len = len(columns[0])
        if any(len(column) != expected_len for column in columns):
            raise ValueError("All values in the dataset_inputs must have the same length.")

        self.video_path_column = video_path_column
        self.video_folder = video_folder

        paths = dataset_inputs[video_path_column]
        if video_folder is not None:
            paths = [os.path.join(video_folder, relative_path) for relative_path in paths]
        self.video_path_list = paths

        self.text_column = text_column
        if text_column is None:
            self.text_list = None
        else:
            self.text_list = dataset_inputs[text_column]

        self.sample_method = sample_method
        self.num_sampled_frames = num_sampled_frames
        self.num_sample_stride = num_sample_stride

    def __getitem__(self, index):
        """Load the sample at ``index``; return ``None`` if reading fails or times out."""
        video_path = self.video_path_list[index]

        if self.sample_method != "image":
            # The reader runs under a timeout as a workaround for decord
            # hanging on some abnormal videos.
            try:
                reader_args = (
                    video_path,
                    self.sample_method,
                    self.num_sampled_frames,
                    self.num_sample_stride,
                )
                frame_indices, frames = func_timeout(
                    VIDEO_READER_TIMEOUT, extract_frames, args=reader_args
                )
            except FunctionTimedOut:
                logger.warning(f"Read {video_path} timeout.")
                return None
            except Exception as e:
                logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
                return None
        else:
            # Image mode: a single RGB frame and no frame indices.
            try:
                frame_indices = None
                with open(video_path, "rb") as image_file:
                    frames = [Image.open(image_file).convert("RGB")]
            except Exception as e:
                logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
                return None

        sample = {
            "path": video_path,
            "sampled_frame_idx": frame_indices,
            "sampled_frame": frames,
        }
        if self.text_list is not None:
            sample["text"] = self.text_list[index]
        return sample

    def __len__(self):
        """Return the number of paths in the dataset."""
        return len(self.video_path_list)
if __name__ == "__main__":
    # Smoke test: collect every video file directly inside `video_folder`,
    # load the files through VideoDataset, and print what each batch contains.
    video_folder = Path("your_video_folder")
    video_path_list = []
    for ext in ALL_VIDEO_EXT:
        video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.glob(f"*.{ext}")]

    # The collected paths are relative to `video_folder`, so the dataset needs
    # the folder to resolve them; otherwise they are looked up relative to the
    # current working directory.
    video_dataset = VideoDataset(
        dataset_inputs={"video_path": video_path_list},
        video_folder=str(video_folder),
    )
    video_dataloader = DataLoader(
        video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn
    )
    for batch in video_dataloader:
        # collate_fn returns an empty dict when every item of the batch failed to load.
        if batch:
            # Items store the file path under "path" (not "video_path").
            print(batch["path"], batch["sampled_frame_idx"], len(batch["path"]))