# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import numpy as np
import random
import torch
import torchvision.io as io
def temporal_sampling(frames, start_idx, end_idx, num_samples):
"""
    Given the start and end frame indices, sample num_samples frames between
    them at equal intervals.
Args:
frames (tensor): a tensor of video frames, dimension is
`num video frames` x `channel` x `height` x `width`.
start_idx (int): the index of the start frame.
end_idx (int): the index of the end frame.
num_samples (int): number of frames to sample.
Returns:
        frames (tensor): a tensor of temporally sampled video frames, dimension is
`num clip frames` x `channel` x `height` x `width`.
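    Example:
        A minimal sketch with dummy data (the shapes are illustrative only):
        >>> frames = torch.rand(64, 3, 224, 224)
        >>> clip = temporal_sampling(frames, 0, 63, 8)
        >>> clip.shape
        torch.Size([8, 3, 224, 224])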
"""
index = torch.linspace(start_idx, end_idx, num_samples)
index = torch.clamp(index, 0, frames.shape[0] - 1).long()
frames = torch.index_select(frames, 0, index)
return frames
def get_start_end_idx(video_size, clip_size, clip_idx, num_clips):
"""
Sample a clip of size clip_size from a video of size video_size and
return the indices of the first and last frame of the clip. If clip_idx is
    -1, the clip is randomly sampled; otherwise the video is uniformly split
    into num_clips clips and the start and end indices of the clip_idx-th
    clip are returned.
Args:
video_size (int): number of overall frames.
clip_size (int): size of the clip to sample from the frames.
clip_idx (int): if clip_idx is -1, perform random jitter sampling. If
            clip_idx is larger than -1, uniformly split the video into num_clips
clips, and select the start and end index of the clip_idx-th video
clip.
num_clips (int): overall number of clips to uniformly sample from the
given video for testing.
Returns:
        start_idx (float): the start frame index (may be fractional).
        end_idx (float): the end frame index (may be fractional).
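    Example:
        With 250 decoded frames, a 64-frame clip, and 10 uniform test clips
        (the numbers are illustrative only):
        >>> get_start_end_idx(250, 64, 0, 10)
        (0.0, 63.0)
        >>> get_start_end_idx(250, 64, 5, 10)
        (93.0, 156.0)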
"""
delta = max(video_size - clip_size, 0)
if clip_idx == -1:
# Random temporal sampling.
start_idx = random.uniform(0, delta)
else:
# Uniformly sample the clip with the given index.
start_idx = delta * clip_idx / num_clips
end_idx = start_idx + clip_size - 1
return start_idx, end_idx
def pyav_decode_stream(
container, start_pts, end_pts, stream, stream_name, buffer_size=0
):
"""
Decode the video with PyAV decoder.
Args:
container (container): PyAV container.
start_pts (int): the starting Presentation TimeStamp to fetch the
video frames.
end_pts (int): the ending Presentation TimeStamp of the decoded frames.
stream (stream): PyAV stream.
stream_name (dict): a dictionary of streams. For example, {"video": 0}
means video stream at stream index 0.
buffer_size (int): number of additional frames to decode beyond end_pts.
Returns:
result (list): list of frames decoded.
max_pts (int): max Presentation TimeStamp of the video sequence.
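    Example:
        A minimal sketch that decodes every frame of a local file (the path is
        hypothetical):
        >>> import av
        >>> container = av.open("video.mp4")
        >>> video_frames, max_pts = pyav_decode_stream(
        ...     container, 0, math.inf, container.streams.video[0], {"video": 0}
        ... )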
"""
    # Seeking in the stream is imprecise, so seek to an earlier PTS by a
    # margin and skip the frames that precede start_pts.
margin = 1024
seek_offset = max(start_pts - margin, 0)
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
frames = {}
buffer_count = 0
max_pts = 0
for frame in container.decode(**stream_name):
max_pts = max(max_pts, frame.pts)
if frame.pts < start_pts:
continue
if frame.pts <= end_pts:
frames[frame.pts] = frame
else:
buffer_count += 1
frames[frame.pts] = frame
if buffer_count >= buffer_size:
break
result = [frames[pts] for pts in sorted(frames)]
return result, max_pts
def torchvision_decode(
video_handle,
sampling_rate,
num_frames,
clip_idx,
video_meta,
num_clips=10,
target_fps=30,
modalities=("visual",),
max_spatial_scale=0,
):
"""
If video_meta is not empty, perform temporal selective decoding to sample a
clip from the video with TorchVision decoder. If video_meta is empty, decode
the entire video and update the video_meta.
Args:
video_handle (bytes): raw bytes of the video file.
sampling_rate (int): frame sampling rate (interval between two sampled
frames).
num_frames (int): number of frames to sample.
clip_idx (int): if clip_idx is -1, perform random temporal
sampling. If clip_idx is larger than -1, uniformly split the
            video into num_clips clips, and select the clip_idx-th video clip.
        video_meta (dict): a dict containing VideoMetaData. Details can be found
at `pytorch/vision/torchvision/io/_video_opt.py`.
num_clips (int): overall number of clips to uniformly sample from the
given video.
        target_fps (int): the input video may have a different fps; convert it
            to the target fps.
        modalities (tuple): tuple of modalities to decode. Currently only
            `visual` is supported; `acoustic` support is planned.
        max_spatial_scale (int): the maximal resolution of the shorter spatial
            edge during decoding.
Returns:
frames (tensor): decoded frames from the video.
fps (float): the number of frames per second of the video.
decode_all_video (bool): if True, the entire video was decoded.
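    Example:
        A minimal sketch on raw file bytes (the path is hypothetical; the
        video_meta dict is populated on the first call):
        >>> with open("video.mp4", "rb") as f:
        ...     video_bytes = f.read()
        >>> meta = {}
        >>> frames, fps, decoded_all = torchvision_decode(
        ...     video_bytes, sampling_rate=2, num_frames=8, clip_idx=-1,
        ...     video_meta=meta,
        ... )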
"""
# Convert the bytes to a tensor.
video_tensor = torch.from_numpy(np.frombuffer(video_handle, dtype=np.uint8))
decode_all_video = True
video_start_pts, video_end_pts = 0, -1
    # If video_meta is empty, fetch the metadata from the raw video.
if len(video_meta) == 0:
        # Probe the raw video for meta info to enable selective decoding later.
meta = io._probe_video_from_memory(video_tensor)
        # Populate video_meta so that future calls can use selective decoding.
video_meta["video_timebase"] = meta.video_timebase
video_meta["video_numerator"] = meta.video_timebase.numerator
video_meta["video_denominator"] = meta.video_timebase.denominator
video_meta["has_video"] = meta.has_video
video_meta["video_duration"] = meta.video_duration
video_meta["video_fps"] = meta.video_fps
video_meta["audio_timebas"] = meta.audio_timebase
video_meta["audio_numerator"] = meta.audio_timebase.numerator
video_meta["audio_denominator"] = meta.audio_timebase.denominator
video_meta["has_audio"] = meta.has_audio
video_meta["audio_duration"] = meta.audio_duration
video_meta["audio_sample_rate"] = meta.audio_sample_rate
fps = video_meta["video_fps"]
if (
video_meta["has_video"]
and video_meta["video_denominator"] > 0
and video_meta["video_duration"] > 0
):
# try selective decoding.
decode_all_video = False
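        # A clip of num_frames frames sampled every sampling_rate frames at
        # target_fps spans this many frames at the video's native fps.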
clip_size = sampling_rate * num_frames / target_fps * fps
start_idx, end_idx = get_start_end_idx(
fps * video_meta["video_duration"], clip_size, clip_idx, num_clips
)
# Convert frame index to pts.
pts_per_frame = video_meta["video_denominator"] / fps
video_start_pts = int(start_idx * pts_per_frame)
video_end_pts = int(end_idx * pts_per_frame)
    # Decode the raw video with the TorchVision decoder.
v_frames, _ = io._read_video_from_memory(
video_tensor,
seek_frame_margin=1.0,
read_video_stream="visual" in modalities,
video_width=0,
video_height=0,
video_min_dimension=max_spatial_scale,
video_pts_range=(video_start_pts, video_end_pts),
video_timebase_numerator=video_meta["video_numerator"],
video_timebase_denominator=video_meta["video_denominator"],
)
if v_frames.shape == torch.Size([0]):
        # Selective decoding failed; fall back to decoding the entire video.
decode_all_video = True
video_start_pts, video_end_pts = 0, -1
v_frames, _ = io._read_video_from_memory(
video_tensor,
seek_frame_margin=1.0,
read_video_stream="visual" in modalities,
video_width=0,
video_height=0,
video_min_dimension=max_spatial_scale,
video_pts_range=(video_start_pts, video_end_pts),
video_timebase_numerator=video_meta["video_numerator"],
video_timebase_denominator=video_meta["video_denominator"],
)
return v_frames, fps, decode_all_video
def pyav_decode(
    container,
    sampling_rate,
    num_frames,
    clip_idx,
    num_clips=10,
    target_fps=30,
    start=None,
    end=None,
    duration=None,
    frames_length=None,
):
"""
    Convert the video from its original fps to the target_fps. If the video
    supports selective decoding (i.e., it contains decoding information in the
    video header), perform temporal selective decoding and sample a clip from
    the video with the PyAV decoder. Otherwise, decode the entire video.
Args:
container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
num_frames (int): number of frames to sample.
clip_idx (int): if clip_idx is -1, perform random temporal sampling. If
            clip_idx is larger than -1, uniformly split the video into num_clips
clips, and select the clip_idx-th video clip.
num_clips (int): overall number of clips to uniformly sample from the
given video.
        target_fps (int): the input video may have a different fps; convert it
            to the target fps before frame sampling.
        start (int): optional start pts; together with `end`, decode exactly
            the pts range [start, end] instead of sampling a clip.
        end (int): optional end pts.
        duration (float): optional video duration in seconds, used as a
            fallback when the container does not expose a duration.
        frames_length (int): optional number of frames; overwritten by the
            frame count read from the container.
    Returns:
        frames (tensor): decoded frames from the video. Returns None if no
            video stream was found.
fps (float): the number of frames per second of the video.
decode_all_video (bool): If True, the entire video was decoded.
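    Example:
        A minimal sketch sampling one random clip (the path is hypothetical):
        >>> import av
        >>> container = av.open("video.mp4")
        >>> frames, fps, decoded_all = pyav_decode(
        ...     container, sampling_rate=2, num_frames=8, clip_idx=-1
        ... )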
"""
    # Try to fetch the decoding information from the video header. Some videos
    # do not support fetching decoding information; in that case the duration
    # is None.
fps = float(container.streams.video[0].average_rate)
orig_duration = duration
tb = float(container.streams.video[0].time_base)
frames_length = container.streams.video[0].frames
duration = container.streams.video[0].duration
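    # Fall back to the duration provided by the caller (assumed to be in
    # seconds), converting it to time-base units.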
if duration is None and orig_duration is not None:
duration = orig_duration / tb
if duration is None:
# If failed to fetch the decoding information, decode the entire video.
decode_all_video = True
video_start_pts, video_end_pts = 0, math.inf
else:
# Perform selective decoding.
decode_all_video = False
start_idx, end_idx = get_start_end_idx(
frames_length,
sampling_rate * num_frames / target_fps * fps,
clip_idx,
num_clips,
)
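        # `duration` is measured in stream time-base units, so this gives the
        # number of pts per frame.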
timebase = duration / frames_length
video_start_pts = int(start_idx * timebase)
video_end_pts = int(end_idx * timebase)
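    # An explicitly requested [start, end] pts range is treated as selective
    # decoding.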
if start is not None and end is not None:
decode_all_video = False
frames = None
# If video stream was found, fetch video frames from the video.
if container.streams.video:
if start is None and end is None:
video_frames, max_pts = pyav_decode_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)
        else:
            # Decode exactly the requested pts range; start and end are
            # already pts values, so no time-base conversion is needed.
            video_frames, max_pts = pyav_decode_stream(
                container,
                start,
                end,
                container.streams.video[0],
                {"video": 0},
            )
container.close()
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
frames = torch.as_tensor(np.stack(frames))
return frames, fps, decode_all_video
def decode(
container,
sampling_rate,
num_frames,
clip_idx=-1,
num_clips=10,
video_meta=None,
target_fps=30,
backend="pyav",
max_spatial_scale=0,
start=None,
end=None,
duration=None,
frames_length=None,
):
"""
Decode the video and perform temporal sampling.
Args:
container (container): pyav container.
sampling_rate (int): frame sampling rate (interval between two sampled
frames).
num_frames (int): number of frames to sample.
clip_idx (int): if clip_idx is -1, perform random temporal
sampling. If clip_idx is larger than -1, uniformly split the
            video into num_clips clips, and select the
clip_idx-th video clip.
num_clips (int): overall number of clips to uniformly
sample from the given video.
        video_meta (dict): a dict containing VideoMetaData. Details can be
            found at `pytorch/vision/torchvision/io/_video_opt.py`.
        target_fps (int): the input video may have a different fps; convert it
            to the target fps before frame sampling.
        backend (str): decoding backend, either `pyav` or `torchvision`.
            Defaults to `pyav`.
        max_spatial_scale (int): keep the aspect ratio and resize the frame so
            that the shorter edge is max_spatial_scale. Only used by the
            `torchvision` backend.
        start (int): optional start pts, forwarded to the `pyav` backend.
        end (int): optional end pts, forwarded to the `pyav` backend.
        duration (float): optional video duration in seconds, forwarded to the
            `pyav` backend.
        frames_length (int): optional frame count, forwarded to the `pyav`
            backend.
Returns:
frames (tensor): decoded frames from the video.
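    Example:
        A minimal sketch of the full pipeline with the default `pyav` backend
        (the path is hypothetical). With `pyav`, the returned tensor has shape
        `num frames` x `height` x `width` x `channel`:
        >>> import av
        >>> container = av.open("video.mp4")
        >>> frames = decode(container, sampling_rate=2, num_frames=8)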
"""
    # Currently supports two decoders: 1) PyAV, and 2) TorchVision.
    assert clip_idx >= -1, "Invalid clip_idx {}".format(clip_idx)
try:
if backend == "pyav":
frames, fps, decode_all_video = pyav_decode(
container,
sampling_rate,
num_frames,
clip_idx,
num_clips,
target_fps,
start,
end,
duration,
frames_length,
)
elif backend == "torchvision":
frames, fps, decode_all_video = torchvision_decode(
container,
sampling_rate,
num_frames,
clip_idx,
video_meta,
num_clips,
target_fps,
("visual",),
max_spatial_scale,
)
else:
raise NotImplementedError(
"Unknown decoding backend {}".format(backend)
)
except Exception as e:
print("Failed to decode by {} with exception: {}".format(backend, e))
return None
    # Return None if the frames were not decoded successfully.
if frames is None or frames.size(0) == 0:
return None
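    # Clip length in decoded frames, rescaled from target_fps to the native
    # fps of the video.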
clip_sz = sampling_rate * num_frames / target_fps * fps
start_idx, end_idx = get_start_end_idx(
frames.shape[0],
clip_sz,
clip_idx if decode_all_video else 0,
num_clips if decode_all_video else 1,
)
# Perform temporal sampling from the decoded video.
frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
return frames