import json
import logging
import os
from pathlib import Path
from typing import Optional, Union

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder

from mmaudio.utils.dist_utils import local_rank

log = logging.getLogger()

# Output resolution / decode frame rate of the two video streams produced per sample.
_CLIP_SIZE = 384
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0

_VIDEO_EXT = '.mp4'


class VideoDataset(Dataset):
    """Base dataset that decodes one ``.mp4`` into two fixed-length frame stacks.

    Each sample is decoded twice from the same file: once at ``_CLIP_FPS`` and
    once at ``_SYNC_FPS``, each followed by its own torchvision transform.

    Subclasses are expected to fill ``self.captions`` (video id -> caption) and
    ``self.videos`` (ordered list of video ids) after calling ``super().__init__``.

    Args:
        video_root: Directory containing ``<video_id>.mp4`` files.
        duration_sec: Number of seconds to keep from the start of every video.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        *,
        duration_sec: float = 8.0,
    ):
        self.video_root = Path(video_root)

        self.duration_sec = duration_sec

        # Number of frames each stream must provide for ``duration_sec`` seconds.
        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        # CLIP stream: resize straight to a square (no crop), convert to float32.
        self.clip_transform = v2.Compose([
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

        # Sync stream: resize, center-crop to a square, convert to float32 and
        # normalize every channel with mean 0.5 / std 0.5.
        self.sync_transform = v2.Compose([
            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # to be implemented by subclasses
        self.captions = {}
        self.videos = sorted(self.captions)

    @staticmethod
    def _check_and_trim(chunk, expected_length: int, stream_name: str, video_id: str):
        """Validate a decoded chunk and cut it to exactly ``expected_length`` frames.

        Raises:
            RuntimeError: If the decoder produced no chunk, or fewer frames than
                ``expected_length``.
        """
        if chunk is None:
            raise RuntimeError(f'{stream_name} video returned None {video_id}')
        if chunk.shape[0] < expected_length:
            raise RuntimeError(
                f'{stream_name} video too short {video_id}, expected {expected_length}, '
                f'got {chunk.shape[0]}')
        # The length check above guarantees the slice has exactly
        # ``expected_length`` frames, so no further length check is needed.
        return chunk[:expected_length]

    def sample(self, idx: int) -> dict[str, Union[str, torch.Tensor]]:
        """Decode and transform the video at position ``idx``.

        Returns:
            A dict with keys ``name`` (video id), ``caption``, ``clip_video`` and
            ``sync_video`` (the two transformed frame stacks).

        Raises:
            RuntimeError: If either stream is missing or shorter than expected.
        """
        video_id = self.videos[idx]
        caption = self.captions[video_id]

        # Register both output streams on a single decoder so the file is read once.
        # Stream order here fixes the order of the chunks returned by pop_chunks().
        reader = StreamingMediaDecoder(self.video_root / (video_id + _VIDEO_EXT))
        reader.add_basic_video_stream(
            frames_per_chunk=self.clip_expected_length,
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=self.sync_expected_length,
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )

        reader.fill_buffer()
        clip_chunk, sync_chunk = reader.pop_chunks()

        # Validate both streams before spending time on any transform.
        clip_chunk = self._check_and_trim(clip_chunk, self.clip_expected_length, 'CLIP',
                                          video_id)
        sync_chunk = self._check_and_trim(sync_chunk, self.sync_expected_length, 'Sync',
                                          video_id)

        return {
            'name': video_id,
            'caption': caption,
            'clip_video': self.clip_transform(clip_chunk),
            'sync_video': self.sync_transform(sync_chunk),
        }

    def __getitem__(self, idx: int) -> Optional[dict[str, Union[str, torch.Tensor]]]:
        """Return the sample at ``idx``, or ``None`` if it could not be loaded.

        Loading is best-effort: any failure inside ``sample`` is logged and
        turned into ``None``, so consumers of this dataset must be prepared to
        receive (and skip) ``None`` entries.
        """
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        return len(self.captions)


class VGGSound(VideoDataset):
    """Test split of VGGSound, described by a header-less CSV file.

    The CSV columns are read as ``id, sec, caption, split``. Only rows whose
    split is ``'test'`` and whose video file ``<id>_<sec:06d>.mp4`` exists in
    ``video_root`` are kept.

    Args:
        video_root: Directory containing the ``.mp4`` files.
        csv_path: Path to the VGGSound CSV file.
        duration_sec: Number of seconds to keep from the start of every video.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        csv_path: Union[str, Path],
        *,
        duration_sec: float = 8.0,
    ):
        super().__init__(video_root, duration_sec=duration_sec)
        self.csv_path = Path(csv_path)

        # A set keeps the per-row existence test O(1) instead of scanning a list.
        available = set(os.listdir(self.video_root))
        if local_rank == 0:
            log.info(f'{len(available)} videos found in {video_root}')

        rows = pd.read_csv(csv_path, header=None,
                           names=['id', 'sec', 'caption', 'split']).to_dict(orient='records')

        self.captions = {}
        missing = []
        for row in rows:
            if row['split'] != 'test':
                continue
            start_sec = int(row['sec'])
            video_id = str(row['id'])
            # this is how our videos are named
            video_name = f'{video_id}_{start_sec:06d}'
            if video_name + _VIDEO_EXT not in available:
                missing.append(video_name)
                continue

            self.captions[video_name] = row['caption']

        if local_rank == 0:
            log.info(f'{len(self.captions)} useable videos found')
            if missing:
                log.info(f'{len(missing)} found in {csv_path} but not in {video_root}')
                log.info(
                    'A small amount is expected, as not all videos are still available on YouTube')

        self.videos = sorted(self.captions)


class MovieGen(VideoDataset):
    """MovieGen videos, each paired with a ``<video_id>.jsonl`` sidecar file.

    The caption of a video is the ``audio_prompt`` field of the JSON document
    stored in its sidecar file.

    Args:
        video_root: Directory containing the ``.mp4`` files.
        jsonl_root: Directory containing one ``<video_id>.jsonl`` per video.
        duration_sec: Number of seconds to keep from the start of every video.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        jsonl_root: Union[str, Path],
        *,
        duration_sec: float = 10.0,
    ):
        super().__init__(video_root, duration_sec=duration_sec)
        self.jsonl_root = Path(jsonl_root)

        # Only ``.mp4`` entries are videos; blindly stripping four characters from
        # every directory entry would turn stray files into bogus video ids.
        # Sorting the full filenames before stripping keeps the sample order stable.
        videos = [
            name[:-len(_VIDEO_EXT)] for name in sorted(os.listdir(self.video_root))
            if name.endswith(_VIDEO_EXT)
        ]

        self.captions = {}
        for v in videos:
            # Each sidecar is parsed as ONE JSON document despite the .jsonl suffix.
            with open(self.jsonl_root / (v + '.jsonl'), encoding='utf-8') as f:
                data = json.load(f)
                self.captions[v] = data['audio_prompt']

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')

        self.videos = videos