# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import math
import random

import numpy as np
import torch
import torchvision.io as io


def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """
    Given the start and end frame index, sample num_samples frames between
    the start and end with equal interval.
    Args:
        frames (tensor): a tensor of video frames, dimension is
            `num video frames` x `channel` x `height` x `width`.
        start_idx (int): the index of the start frame.
        end_idx (int): the index of the end frame.
        num_samples (int): number of frames to sample.
    Returns:
        frames (tensor): a tensor of temporally sampled video frames,
            dimension is `num clip frames` x `channel` x `height` x `width`.
    """
    index = torch.linspace(start_idx, end_idx, num_samples)
    index = torch.clamp(index, 0, frames.shape[0] - 1).long()
    frames = torch.index_select(frames, 0, index)
    return frames


def get_start_end_idx(video_size, clip_size, clip_idx, num_clips):
    """
    Sample a clip of size clip_size from a video of size video_size and
    return the indices of the first and last frame of the clip. If clip_idx
    is -1, the clip is randomly sampled; otherwise uniformly split the video
    to num_clips clips and select the start and end index of the clip_idx-th
    video clip.
    Args:
        video_size (int): number of overall frames.
        clip_size (int): size of the clip to sample from the frames.
        clip_idx (int): if clip_idx is -1, perform random jitter sampling. If
            clip_idx is larger than -1, uniformly split the video to num_clips
            clips, and select the start and end index of the clip_idx-th video
            clip.
        num_clips (int): overall number of clips to uniformly sample from the
            given video for testing.
    Returns:
        start_idx (float): the start frame index.
        end_idx (float): the end frame index.
    """
    delta = max(video_size - clip_size, 0)
    if clip_idx == -1:
        # Random temporal jitter sampling.
        start_idx = random.uniform(0, delta)
    else:
        # Uniformly sample the clip with the given index.
        start_idx = delta * clip_idx / num_clips
    end_idx = start_idx + clip_size - 1
    return start_idx, end_idx
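
# Illustrative sketch (not part of the original API): how `get_start_end_idx`
# and `temporal_sampling` compose. Assuming the video fps already equals
# target_fps, a clip of 8 frames at sampling rate 8 covers 64 frames of
# footage. For a 300-frame video, delta = 300 - 64 = 236, so test clip 9 of
# 10 starts at frame 236 * 9 / 10 = 212.4. The tensor shape below is an
# assumption for demonstration only.
def _example_clip_sampling():
    fake_video = torch.zeros(300, 3, 224, 224)  # `num frames` x C x H x W
    clip_size = 8 * 8  # sampling_rate * num_frames
    start_idx, end_idx = get_start_end_idx(
        video_size=fake_video.shape[0],
        clip_size=clip_size,
        clip_idx=9,
        num_clips=10,
    )
    # start_idx = 212.4, end_idx = 212.4 + 64 - 1 = 275.4.
    clip = temporal_sampling(fake_video, start_idx, end_idx, num_samples=8)
    return clip  # 8 x 3 x 224 x 224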
def pyav_decode_stream(
    container, start_pts, end_pts, stream, stream_name, buffer_size=0
):
    """
    Decode the video with the PyAV decoder.
    Args:
        container (container): PyAV container.
        start_pts (int): the starting Presentation TimeStamp to fetch the
            video frames.
        end_pts (int): the ending Presentation TimeStamp of the decoded
            frames.
        stream (stream): PyAV stream.
        stream_name (dict): a dictionary of streams. For example, {"video": 0}
            means video stream at stream index 0.
        buffer_size (int): number of additional frames to decode beyond
            end_pts.
    Returns:
        result (list): list of frames decoded.
        max_pts (int): max Presentation TimeStamp of the video sequence.
    """
    # Seeking in the stream is imprecise. Thus, seek to an earlier PTS by a
    # margin pts.
    margin = 1024
    seek_offset = max(start_pts - margin, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    frames = {}
    buffer_count = 0
    max_pts = 0
    for frame in container.decode(**stream_name):
        max_pts = max(max_pts, frame.pts)
        if frame.pts < start_pts:
            continue
        if frame.pts <= end_pts:
            frames[frame.pts] = frame
        else:
            # Keep decoding up to buffer_size extra frames past end_pts.
            buffer_count += 1
            frames[frame.pts] = frame
            if buffer_count >= buffer_size:
                break
    result = [frames[pts] for pts in sorted(frames)]
    return result, max_pts


def torchvision_decode(
    video_handle,
    sampling_rate,
    num_frames,
    clip_idx,
    video_meta,
    num_clips=10,
    target_fps=30,
    modalities=("visual",),
    max_spatial_scale=0,
):
    """
    If video_meta is not empty, perform temporal selective decoding to sample
    a clip from the video with the TorchVision decoder. If video_meta is
    empty, decode the entire video and update the video_meta.
    Args:
        video_handle (bytes): raw bytes of the video file.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling.
            If clip_idx is larger than -1, uniformly split the video to
            num_clips clips, and select the clip_idx-th video clip.
        video_meta (dict): a dict containing VideoMetaData. Details can be
            found at `pytorch/vision/torchvision/io/_video_opt.py`.
        num_clips (int): overall number of clips to uniformly sample from the
            given video.
        target_fps (int): the input video may have a different fps, convert it
            to the target video fps.
        modalities (tuple): tuple of modalities to decode. Currently only
            `visual` is supported; `acoustic` is planned.
        max_spatial_scale (int): the maximal resolution of the spatial shorter
            edge size during decoding.
    Returns:
        frames (tensor): decoded frames from the video.
        fps (float): the number of frames per second of the video.
        decode_all_video (bool): if True, the entire video was decoded.
    """
    # Convert the bytes to a tensor.
    video_tensor = torch.from_numpy(
        np.frombuffer(video_handle, dtype=np.uint8)
    )

    decode_all_video = True
    video_start_pts, video_end_pts = 0, -1
    # The video_meta is empty, fetch the meta data from the raw video.
    if len(video_meta) == 0:
        # Track the meta info for selective decoding in the future.
        meta = io._probe_video_from_memory(video_tensor)
        # Use the information from video_meta to perform selective decoding.
        video_meta["video_timebase"] = meta.video_timebase
        video_meta["video_numerator"] = meta.video_timebase.numerator
        video_meta["video_denominator"] = meta.video_timebase.denominator
        video_meta["has_video"] = meta.has_video
        video_meta["video_duration"] = meta.video_duration
        video_meta["video_fps"] = meta.video_fps
        video_meta["audio_timebase"] = meta.audio_timebase
        video_meta["audio_numerator"] = meta.audio_timebase.numerator
        video_meta["audio_denominator"] = meta.audio_timebase.denominator
        video_meta["has_audio"] = meta.has_audio
        video_meta["audio_duration"] = meta.audio_duration
        video_meta["audio_sample_rate"] = meta.audio_sample_rate

    fps = video_meta["video_fps"]
    if (
        video_meta["has_video"]
        and video_meta["video_denominator"] > 0
        and video_meta["video_duration"] > 0
    ):
        # Try selective decoding.
        decode_all_video = False
        clip_size = sampling_rate * num_frames / target_fps * fps
        start_idx, end_idx = get_start_end_idx(
            fps * video_meta["video_duration"], clip_size, clip_idx, num_clips
        )
        # Convert frame index to pts.
        pts_per_frame = video_meta["video_denominator"] / fps
        video_start_pts = int(start_idx * pts_per_frame)
        video_end_pts = int(end_idx * pts_per_frame)

    # Decode the raw video with the tv decoder.
    v_frames, _ = io._read_video_from_memory(
        video_tensor,
        seek_frame_margin=1.0,
        read_video_stream="visual" in modalities,
        video_width=0,
        video_height=0,
        video_min_dimension=max_spatial_scale,
        video_pts_range=(video_start_pts, video_end_pts),
        video_timebase_numerator=video_meta["video_numerator"],
        video_timebase_denominator=video_meta["video_denominator"],
    )

    if v_frames.shape == torch.Size([0]):
        # Selective decoding failed; fall back to decoding the entire video.
        decode_all_video = True
        video_start_pts, video_end_pts = 0, -1
        v_frames, _ = io._read_video_from_memory(
            video_tensor,
            seek_frame_margin=1.0,
            read_video_stream="visual" in modalities,
            video_width=0,
            video_height=0,
            video_min_dimension=max_spatial_scale,
            video_pts_range=(video_start_pts, video_end_pts),
            video_timebase_numerator=video_meta["video_numerator"],
            video_timebase_denominator=video_meta["video_denominator"],
        )

    return v_frames, fps, decode_all_video
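
# Illustrative sketch (not part of the original API): the frame-index -> PTS
# conversion used by `torchvision_decode` above. With a 30 fps stream and an
# assumed timebase denominator of 15360 (i.e. 15360 pts ticks per second),
# each frame spans 15360 / 30 = 512 pts, so frame 100 starts at pts 51200.
# All numbers here are assumptions for demonstration only.
def _example_frame_idx_to_pts():
    fps = 30.0
    video_denominator = 15360  # timebase denominator: pts ticks per second
    pts_per_frame = video_denominator / fps  # 512 pts per frame
    start_idx, end_idx = 100, 163
    video_start_pts = int(start_idx * pts_per_frame)  # 51200
    video_end_pts = int(end_idx * pts_per_frame)  # 83456
    return video_start_pts, video_end_pts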
def pyav_decode(
    container,
    sampling_rate,
    num_frames,
    clip_idx,
    num_clips=10,
    target_fps=30,
    start=None,
    end=None,
    duration=None,
    frames_length=None,
):
    """
    Convert the video from its original fps to the target_fps. If the video
    supports selective decoding (contains decoding information in the video
    header), then perform temporal selective decoding and sample a clip from
    the video with the PyAV decoder. If the video does not support selective
    decoding, decode the entire video.

    Args:
        container (container): PyAV container.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling.
            If clip_idx is larger than -1, uniformly split the video to
            num_clips clips, and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the
            given video.
        target_fps (int): the input video may have a different fps, convert it
            to the target video fps before frame sampling.
        start (int): if given together with `end`, decode the pts range
            [start, end] directly instead of deriving it from clip_idx.
        end (int): see `start`.
        duration (float): fallback video duration in seconds, used when the
            container does not carry a duration itself.
        frames_length (int): unused; the frame count is read from the
            container.
    Returns:
        frames (tensor): decoded frames from the video. Return None if no
            video stream was found.
        fps (float): the number of frames per second of the video.
        decode_all_video (bool): if True, the entire video was decoded.
    """
    # Try to fetch the decoding information from the video header. Some
    # videos do not support fetching the decoding information; in that case
    # the duration is None.
    fps = float(container.streams.video[0].average_rate)

    orig_duration = duration
    tb = float(container.streams.video[0].time_base)
    frames_length = container.streams.video[0].frames
    duration = container.streams.video[0].duration
    if duration is None and orig_duration is not None:
        # Convert the fallback duration from seconds to pts units.
        duration = orig_duration / tb

    if duration is None:
        # Failed to fetch the decoding information, decode the entire video.
        decode_all_video = True
        video_start_pts, video_end_pts = 0, math.inf
    else:
        # Perform selective decoding.
        decode_all_video = False
        start_idx, end_idx = get_start_end_idx(
            frames_length,
            sampling_rate * num_frames / target_fps * fps,
            clip_idx,
            num_clips,
        )
        # Convert frame index to pts: duration is in pts units, so
        # duration / frames_length is the number of pts per frame.
        timebase = duration / frames_length
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)

    if start is not None and end is not None:
        decode_all_video = False

    frames = None
    # If a video stream was found, fetch video frames from the video.
    if container.streams.video:
        if start is None and end is None:
            video_frames, max_pts = pyav_decode_stream(
                container,
                video_start_pts,
                video_end_pts,
                container.streams.video[0],
                {"video": 0},
            )
        else:
            # Decode the externally provided pts range directly.
            video_frames, max_pts = pyav_decode_stream(
                container,
                start,
                end,
                container.streams.video[0],
                {"video": 0},
            )
        container.close()
        frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
        frames = torch.as_tensor(np.stack(frames))
    return frames, fps, decode_all_video
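
# Illustrative sketch (not part of the original API): decoding one test clip
# with `pyav_decode`. The file path is hypothetical; any local video works.
# PyAV (`pip install av`) provides `av.open`.
def _example_pyav_decode():
    import av

    container = av.open("/path/to/video.mp4")  # hypothetical path
    frames, fps, decode_all_video = pyav_decode(
        container,
        sampling_rate=2,
        num_frames=8,
        clip_idx=0,  # first of `num_clips` uniformly spaced test clips
        num_clips=10,
        target_fps=30,
    )
    # `frames` is `num decoded frames` x H x W x C (uint8), or None if no
    # video stream was found; `pyav_decode` closes the container itself.
    return frames, fps, decode_all_video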
def decode(
    container,
    sampling_rate,
    num_frames,
    clip_idx=-1,
    num_clips=10,
    video_meta=None,
    target_fps=30,
    backend="pyav",
    max_spatial_scale=0,
    start=None,
    end=None,
    duration=None,
    frames_length=None,
):
    """
    Decode the video and perform temporal sampling.
    Args:
        container (container): PyAV container.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling.
            If clip_idx is larger than -1, uniformly split the video to
            num_clips clips, and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the
            given video.
        video_meta (dict): a dict containing VideoMetaData. Details can be
            found at `pytorch/vision/torchvision/io/_video_opt.py`.
        target_fps (int): the input video may have a different fps, convert it
            to the target video fps before frame sampling.
        backend (str): decoding backend, either `pyav` or `torchvision`. The
            default is `pyav`.
        max_spatial_scale (int): keep the aspect ratio and resize the frame so
            that the shorter edge size is max_spatial_scale. Only used in the
            `torchvision` backend.
        start (int): optional start pts, forwarded to `pyav_decode`.
        end (int): optional end pts, forwarded to `pyav_decode`.
        duration (float): optional fallback duration in seconds, forwarded to
            `pyav_decode`.
        frames_length (int): forwarded to `pyav_decode` (unused there).
    Returns:
        frames (tensor): decoded frames from the video.
    """
    # Currently support two decoders: 1) PyAV, and 2) TorchVision.
    assert clip_idx >= -1, "Not valid clip_idx {}".format(clip_idx)
    if video_meta is None:
        video_meta = {}
    try:
        if backend == "pyav":
            frames, fps, decode_all_video = pyav_decode(
                container,
                sampling_rate,
                num_frames,
                clip_idx,
                num_clips,
                target_fps,
                start,
                end,
                duration,
                frames_length,
            )
        elif backend == "torchvision":
            frames, fps, decode_all_video = torchvision_decode(
                container,
                sampling_rate,
                num_frames,
                clip_idx,
                video_meta,
                num_clips,
                target_fps,
                ("visual",),
                max_spatial_scale,
            )
        else:
            raise NotImplementedError(
                "Unknown decoding backend {}".format(backend)
            )
    except Exception as e:
        print("Failed to decode by {} with exception: {}".format(backend, e))
        return None

    # Return None if the frames were not decoded successfully.
    if frames is None or frames.size(0) == 0:
        return None

    clip_sz = sampling_rate * num_frames / target_fps * fps
    # If the entire video was decoded, sample the requested clip here;
    # otherwise the decoder already returned the selected clip.
    start_idx, end_idx = get_start_end_idx(
        frames.shape[0],
        clip_sz,
        clip_idx if decode_all_video else 0,
        num_clips if decode_all_video else 1,
    )
    # Perform temporal sampling from the decoded video.
    frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
    return frames
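
# Illustrative sketch (not part of the original API): an end-to-end call that
# decodes a random training clip of 8 frames sampled every 2 frames at a
# target of 30 fps. The file path is hypothetical.
def _example_decode():
    import av

    container = av.open("/path/to/video.mp4")  # hypothetical path
    frames = decode(
        container,
        sampling_rate=2,
        num_frames=8,
        clip_idx=-1,  # -1 => random temporal jitter for training
        num_clips=10,
        video_meta={},
        target_fps=30,
        backend="pyav",
    )
    # `frames` is `num_frames` x H x W x C, or None if decoding failed.
    return frames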