File size: 15,362 Bytes
3eb682b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import numpy as np
import random
import torch
import torchvision.io as io
def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """
    Pick `num_samples` frames at evenly spaced positions between two frame
    indices.
    Args:
        frames (tensor): a tensor of video frames, dimension is
            `num video frames` x `channel` x `height` x `width`. Only the
            first dimension is indexed here.
        start_idx (int or float): position of the first sampled frame.
        end_idx (int or float): position of the last sampled frame.
        num_samples (int): number of frames to sample.
    Returns:
        frames (tensor): the temporally sampled frames, dimension is
            `num clip frames` x `channel` x `height` x `width`.
    """
    last_valid = frames.shape[0] - 1
    positions = torch.linspace(start_idx, end_idx, num_samples)
    # Positions that fall outside the video are pinned to its first / last
    # frame; the fractional part is then dropped to obtain integer indices.
    positions = positions.clamp(0, last_valid).long()
    return frames.index_select(0, positions)
def get_start_end_idx(video_size, clip_size, clip_idx, num_clips):
    """
    Locate a clip of `clip_size` frames inside a video of `video_size` frames
    and return the positions of its first and last frame. With clip_idx == -1
    the clip is placed at random; otherwise the video is split uniformly into
    num_clips clips and the clip_idx-th one is selected.
    Args:
        video_size (int): number of overall frames.
        clip_size (int): size of the clip to sample from the frames.
        clip_idx (int): -1 requests random placement. Any other value selects
            the clip_idx-th of num_clips uniformly spaced clips.
        num_clips (int): overall number of clips to uniformly sample from the
            given video for testing. Unused when clip_idx is -1.
    Returns:
        start_idx (float): the start frame position. May be fractional, since
            it comes from `random.uniform` or a true division.
        end_idx (float): the end frame position, `start_idx + clip_size - 1`.
    """
    # Room left for moving the clip around; zero when the video is not longer
    # than the clip, which forces the clip to start at the first frame.
    slack = max(video_size - clip_size, 0)
    if clip_idx == -1:
        # Random temporal placement.
        first = random.uniform(0, slack)
    else:
        # Deterministic placement of the clip with the given index.
        first = slack * clip_idx / num_clips
    return first, first + clip_size - 1
def pyav_decode_stream(
    container, start_pts, end_pts, stream, stream_name, buffer_size=0
):
    """
    Decode the frames of one stream that fall inside a PTS window, using the
    PyAV decoder.
    Args:
        container (container): PyAV container.
        start_pts (int): the starting Presentation TimeStamp to fetch the
            video frames.
        end_pts (int): the ending Presentation TimeStamp of the decoded frames.
        stream (stream): PyAV stream, used as the seek target.
        stream_name (dict): a dictionary of streams forwarded to
            `container.decode`. For example, {"video": 0} means video stream
            at stream index 0.
        buffer_size (int): number of additional frames to decode beyond
            end_pts. Decoding stops at the first frame past end_pts once this
            many such frames have been kept, so at least one frame past
            end_pts is kept when one exists.
    Returns:
        result (list): list of frames decoded, ordered by PTS.
        max_pts (int): largest Presentation TimeStamp seen while decoding,
            including frames before start_pts that were skipped.
    """
    # Seeking in the stream is imprecise, so aim for an earlier PTS by a fixed
    # margin and discard the frames that precede start_pts while decoding.
    seek_margin = 1024
    container.seek(
        max(start_pts - seek_margin, 0),
        any_frame=False,
        backward=True,
        stream=stream,
    )
    by_pts = {}
    past_end = 0
    max_pts = 0
    for frame in container.decode(**stream_name):
        pts = frame.pts
        max_pts = max(max_pts, pts)
        if pts < start_pts:
            continue
        # Keyed by PTS so frames can be returned in presentation order.
        by_pts[pts] = frame
        if pts <= end_pts:
            continue
        past_end += 1
        if past_end >= buffer_size:
            break
    return [by_pts[key] for key in sorted(by_pts)], max_pts
def torchvision_decode(
    video_handle,
    sampling_rate,
    num_frames,
    clip_idx,
    video_meta,
    num_clips=10,
    target_fps=30,
    modalities=("visual",),
    max_spatial_scale=0,
):
    """
    Decode a clip from an in-memory video with the TorchVision decoder. When
    the metadata allows it, only the PTS window of the requested clip is
    decoded (selective decoding); otherwise, or when selective decoding
    returns no frames, the entire video is decoded.
    Args:
        video_handle (bytes): raw bytes of the video file.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal
            sampling. If clip_idx is larger than -1, uniformly split the
            video to num_clips clips, and select the clip_idx-th video clip.
        video_meta (dict or None): a dict of video metadata. When empty it is
            filled in place from the probed video so the caller can reuse it;
            when already populated it is used as is. None is accepted and
            treated as an empty dict that is discarded afterwards. Details of
            the metadata can be found at
            `pytorch/vision/torchvision/io/_video_opt.py`.
        num_clips (int): overall number of clips to uniformly sample from the
            given video.
        target_fps (int): the input video may have a different fps; the clip
            length is rescaled from the target fps to the video fps.
        modalities (tuple): tuple of modalities to decode. Only `visual` is
            consulted here.
        max_spatial_scale (int): the maximal resolution of the spatial shorter
            edge size during decoding.
    Returns:
        frames (tensor): decoded frames from the video.
        fps (float): the number of frames per second of the video.
        decode_all_video (bool): if True, the entire video was decoded.
    """
    # A caller that has no metadata cache may pass None; the lookups below
    # need a real dict.
    if video_meta is None:
        video_meta = {}

    # Convert the bytes to a tensor.
    video_tensor = torch.from_numpy(np.frombuffer(video_handle, dtype=np.uint8))

    def _read_range(start_pts, end_pts):
        # Single place for the decoder call so the selective attempt and the
        # full-video fallback cannot drift apart.
        decoded, _ = io._read_video_from_memory(
            video_tensor,
            seek_frame_margin=1.0,
            read_video_stream="visual" in modalities,
            video_width=0,
            video_height=0,
            video_min_dimension=max_spatial_scale,
            video_pts_range=(start_pts, end_pts),
            video_timebase_numerator=video_meta["video_numerator"],
            video_timebase_denominator=video_meta["video_denominator"],
        )
        return decoded

    # The video_meta is empty, fetch the meta data from the raw video and
    # keep it for selective decoding in the future.
    if not video_meta:
        meta = io._probe_video_from_memory(video_tensor)
        video_meta["video_timebase"] = meta.video_timebase
        video_meta["video_numerator"] = meta.video_timebase.numerator
        video_meta["video_denominator"] = meta.video_timebase.denominator
        video_meta["has_video"] = meta.has_video
        video_meta["video_duration"] = meta.video_duration
        video_meta["video_fps"] = meta.video_fps
        # NOTE(review): the key is misspelled ("audio_timebas"); it is kept
        # byte-for-byte because existing readers of video_meta may rely on it.
        video_meta["audio_timebas"] = meta.audio_timebase
        video_meta["audio_numerator"] = meta.audio_timebase.numerator
        video_meta["audio_denominator"] = meta.audio_timebase.denominator
        video_meta["has_audio"] = meta.has_audio
        video_meta["audio_duration"] = meta.audio_duration
        video_meta["audio_sample_rate"] = meta.audio_sample_rate

    fps = video_meta["video_fps"]
    # The PTS range (0, -1) is what this function uses to request the whole
    # video.
    decode_all_video = True
    video_start_pts, video_end_pts = 0, -1
    if (
        video_meta["has_video"]
        and video_meta["video_denominator"] > 0
        and video_meta["video_duration"] > 0
    ):
        # Try selective decoding.
        decode_all_video = False
        clip_size = sampling_rate * num_frames / target_fps * fps
        # fps * video_duration is used as the total frame count, i.e.
        # video_duration is treated as seconds here.
        start_idx, end_idx = get_start_end_idx(
            fps * video_meta["video_duration"], clip_size, clip_idx, num_clips
        )
        # Convert frame index to pts.
        pts_per_frame = video_meta["video_denominator"] / fps
        video_start_pts = int(start_idx * pts_per_frame)
        video_end_pts = int(end_idx * pts_per_frame)

    # Decode the raw video with the tv decoder.
    v_frames = _read_range(video_start_pts, video_end_pts)

    # Selective decoding produced nothing: fall back to the whole video. When
    # the first read already covered the whole video, repeating the identical
    # call would not help, so it is skipped.
    if not decode_all_video and v_frames.shape == torch.Size([0]):
        decode_all_video = True
        v_frames = _read_range(0, -1)

    return v_frames, fps, decode_all_video
def pyav_decode(
    container,
    sampling_rate,
    num_frames,
    clip_idx,
    num_clips=10,
    target_fps=30,
    start=None,
    end=None,
    duration=None,
    frames_length=None,
):
    """
    Decode a clip from a video with the PyAV decoder. If an explicit PTS
    range is given it is decoded as is. Otherwise, if the video exposes its
    duration and frame count, perform temporal selective decoding of the
    requested clip; if it does not, decode the entire video.
    Args:
        container (container): pyav container. It is closed once the frames
            have been read from it.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling. If
            clip_idx is larger than -1, uniformly split the video to num_clips
            clips, and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the
            given video.
        target_fps (int): the input video may have a different fps; the clip
            length is rescaled from the target fps to the video fps.
        start (int or None): start PTS of an explicit range to decode. Only
            honoured when `end` is given as well.
        end (int or None): end PTS of an explicit range to decode. Only
            honoured when `start` is given as well.
        duration (number or None): fallback duration used when the stream does
            not report one. It is divided by the stream time base to obtain
            PTS units (presumably seconds -- confirm against the caller).
        frames_length (int or None): fallback frame count used when the stream
            does not report a non-zero one.
    Returns:
        frames (tensor): decoded frames from the video.
        fps (float): the number of frames per second of the video.
        decode_all_video (bool): If True, the entire video was decoded.
    Note:
        The first video stream is accessed unconditionally via
        `container.streams.video[0]`, so a container with no video stream
        fails there instead of returning None frames.
    """
    video_stream = container.streams.video[0]
    fps = float(video_stream.average_rate)
    tb = float(video_stream.time_base)

    # Try to fetch the decoding information from the video head. Values
    # reported by the stream win; the caller-supplied ones are only fallbacks
    # for videos whose header does not carry the information.
    stream_frames = video_stream.frames
    if stream_frames:
        frames_length = stream_frames
    stream_duration = video_stream.duration
    if stream_duration is not None:
        duration = stream_duration
    elif duration is not None:
        duration = duration / tb

    if start is not None and end is not None:
        # Explicit PTS range: no clip placement needed.
        decode_all_video = False
        video_start_pts, video_end_pts = start, end
    elif duration is None or not frames_length:
        # If failed to fetch the decoding information, decode the entire
        # video. A missing or zero frame count is treated the same way, since
        # the pts-per-frame ratio below cannot be computed without it.
        decode_all_video = True
        video_start_pts, video_end_pts = 0, math.inf
    else:
        # Perform selective decoding.
        decode_all_video = False
        start_idx, end_idx = get_start_end_idx(
            frames_length,
            sampling_rate * num_frames / target_fps * fps,
            clip_idx,
            num_clips,
        )
        # PTS units per frame, used to convert frame positions to PTS.
        timebase = duration / frames_length
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)

    video_frames, _ = pyav_decode_stream(
        container,
        video_start_pts,
        video_end_pts,
        video_stream,
        {"video": 0},
    )
    container.close()

    frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
    frames = torch.as_tensor(np.stack(frames))
    return frames, fps, decode_all_video
def decode(
    container,
    sampling_rate,
    num_frames,
    clip_idx=-1,
    num_clips=10,
    video_meta=None,
    target_fps=30,
    backend="pyav",
    max_spatial_scale=0,
    start=None,
    end=None,
    duration=None,
    frames_length=None,
):
    """
    Decode the video and perform temporal sampling.
    Args:
        container (container): the video to decode. It is forwarded as the
            pyav container to the `pyav` backend, or as the `video_handle`
            argument to the `torchvision` backend.
        sampling_rate (int): frame sampling rate (interval between two sampled
            frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal
            sampling. If clip_idx is larger than -1, uniformly split the
            video to num_clips clips, and select the
            clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly
            sample from the given video.
        video_meta (dict or None): a dict of video metadata, only used by the
            `torchvision` backend. None is treated as an empty dict. Details
            can be found at `pytorch/vision/torchvision/io/_video_opt.py`.
        target_fps (int): the input video may have different fps, convert it to
            the target video fps before frame sampling.
        backend (str): decoding backend includes `pyav` and `torchvision`. The
            default one is `pyav`.
        max_spatial_scale (int): keep the aspect ratio and resize the frame so
            that shorter edge size is max_spatial_scale. Only used in
            `torchvision` backend.
        start, end, duration, frames_length: forwarded unchanged to
            `pyav_decode`; only used in `pyav` backend.
    Returns:
        frames (tensor): decoded and temporally sampled frames, or None when
            decoding raised (the error is printed), the backend is unknown,
            or no frame was decoded.
    """
    # Currently support two decoders: 1) PyAV, and 2) TorchVision.
    assert clip_idx >= -1, "Not valied clip_idx {}".format(clip_idx)
    try:
        if backend == "pyav":
            frames, fps, decode_all_video = pyav_decode(
                container,
                sampling_rate,
                num_frames,
                clip_idx,
                num_clips,
                target_fps,
                start,
                end,
                duration,
                frames_length,
            )
        elif backend == "torchvision":
            # The torchvision backend reads and fills video_meta as a dict;
            # the None default would otherwise make every call fail inside
            # the try and be reported as a decode failure.
            if video_meta is None:
                video_meta = {}
            frames, fps, decode_all_video = torchvision_decode(
                container,
                sampling_rate,
                num_frames,
                clip_idx,
                video_meta,
                num_clips,
                target_fps,
                ("visual",),
                max_spatial_scale,
            )
        else:
            raise NotImplementedError(
                "Unknown decoding backend {}".format(backend)
            )
    except Exception as e:
        # Deliberate best-effort: any backend failure is reported and turned
        # into a None result for the caller to handle.
        print("Failed to decode by {} with exception: {}".format(backend, e))
        return None

    # Return None if the frames was not decoded successfully.
    if frames is None or frames.size(0) == 0:
        return None

    clip_sz = sampling_rate * num_frames / target_fps * fps
    # When only the requested clip was decoded, the decoded frames are
    # treated as a single clip (index 0 of 1), so sampling starts at their
    # first frame. When the whole video was decoded, the requested clip is
    # located inside it here.
    start_idx, end_idx = get_start_end_idx(
        frames.shape[0],
        clip_sz,
        clip_idx if decode_all_video else 0,
        num_clips if decode_all_video else 1,
    )
    # Perform temporal sampling from the decoded video.
    frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
    return frames
|