Spaces:
Runtime error
Runtime error
import math | |
import torch | |
import torch.nn as nn | |
from pytorchvideo import transforms as pv_transforms | |
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler | |
from pytorchvideo.data.encoded_video import EncodedVideo | |
from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord | |
from torchvision import transforms | |
from torchvision.transforms._transforms_video import NormalizeVideo | |
def get_clip_timepoints(clip_sampler, duration):
    """Collect the (start, end) span of every clip produced by a sampler.

    The sampler is called repeatedly, each time being handed the end time
    returned by the previous call (0.0 for the first call), until it reports
    that the clip it just returned is the last one.

    Args:
        clip_sampler (callable): called as
            `clip_sampler(last_end, duration, annotation=None)` and expected
            to return a 5-tuple whose first, second and fifth items are the
            clip start, the clip end and an "is last clip" flag.
        duration: total duration of the video, forwarded to the sampler.
    Returns:
        list: one `(start, end)` tuple per sampled clip, in sampling order.
        The sampler is always queried at least once.
    """
    timepoints = []
    last_end = 0.0
    while True:
        start, last_end, _, _, finished = clip_sampler(
            last_end, duration, annotation=None
        )
        timepoints.append((start, last_end))
        if finished:
            break
    return timepoints
def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.

    The boxes are translated into the coordinate frame of the crop by
    subtracting the crop offsets; the input array is never modified.

    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4. Columns 0 and 2 are shifted by `x_offset`,
            columns 1 and 3 by `y_offset`.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4, or None when `boxes` is None.
    """
    # The documented contract allows None; without this guard `boxes.copy()`
    # would raise AttributeError.
    if boxes is None:
        return None

    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Take one of three evenly spaced square crops from the images (and boxes).

    Args:
        images (tensor): images to crop. The dimension is
            `num frames` x `channel` x `height` x `width`; a 3-dimensional
            tensor is treated as a single image and returned 3-dimensional.
        size (int): height and width of the square crop.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, the images are resized so that
            their shorter side equals scale_size before any crop is taken.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]

    single_image = len(images.shape) == 3
    if single_image:
        # Add a leading dimension so the indexing below is uniform.
        images = images.unsqueeze(0)
    height, width = images.shape[2], images.shape[3]

    if scale_size is not None:
        # Resize so the shorter side becomes scale_size, keeping aspect ratio.
        if width <= height:
            height = int(height / width * scale_size)
            width = scale_size
        else:
            width = int(width / height * scale_size)
            height = scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    # Start from the centered crop, then slide along the longer axis only.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))
    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    elif spatial_idx == 0:
        x_offset = 0
    elif spatial_idx == 2:
        x_offset = width - size

    rows = slice(y_offset, y_offset + size)
    cols = slice(x_offset, x_offset + size)
    cropped = images[:, :, rows, cols]

    # NOTE(review): boxes are only shifted by the crop offsets; they are not
    # rescaled when scale_size resizes the images - confirm this is intended.
    if boxes is None:
        cropped_boxes = None
    else:
        cropped_boxes = crop_boxes(boxes, x_offset, y_offset)

    if single_image:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
class SpatialCrop(nn.Module):
    """
    Expand a list of videos into their spatial crops.

    Meant to run after the temporal crops have been taken, on full frames
    (the original note says to use -2 for the spatial crop at the slowfast
    augmentation stage so that full frames reach this module). With
    `num_crops=3` every video yields three crops (indices 0, 1, 2 of
    `uniform_crop`); with `num_crops=1` only the center crop (index 1).
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        """
        Args:
            crop_size: side length of each square crop.
            num_crops: number of crops per video; only 1 and 3 are supported.
        Raises:
            NotImplementedError: for any other `num_crops`.
        """
        super().__init__()
        if num_crops != 3 and num_crops != 1:
            raise NotImplementedError("Nothing else supported yet")
        self.crop_size = crop_size
        self.crops_to_ext = [0, 1, 2] if num_crops == 3 else [1]
        # No flipped crops are configured for either supported mode.
        self.flipped_crops_to_ext = []

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with one entry per (video, crop index) pair, i.e.
            `len(self.crops_to_ext)` times the number of inputs. Each video
            converted to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all(video.ndim == 4 for video in videos), "Must be (C,T,H,W)"

        crops = []
        for video in videos:
            crops.extend(
                uniform_crop(video, self.crop_size, idx)[0]
                for idx in self.crops_to_ext
            )
            if self.flipped_crops_to_ext:
                # Same crops again on the horizontally mirrored video.
                mirrored = transforms.functional.hflip(video)
                crops.extend(
                    uniform_crop(mirrored, self.crop_size, idx)[0]
                    for idx in self.flipped_crops_to_ext
                )
        return crops
def load_and_transform_video_data(
    video_file,
    video_path,
    clip_duration=2,
    clips_per_video=5,
    sample_rate=16000,
    with_audio=False
):
    """
    Decode a video, sample clips from it and return normalized spatial crops.

    Each sampled clip is temporally subsampled, divided by 255, resized so
    its short side is 224, normalized, and then expanded into three 224x224
    spatial crops. All crops are stacked along a new leading dimension.

    Args:
        video_file: a path string, opened with `EncodedVideo.from_path`
            using the decord decoder; any other value is handed to
            `EncodedVideoDecord` directly (presumably a file-like object -
            TODO confirm).
        video_path: used only as `video_name` on the non-string branch.
        clip_duration: clip length given to the clip sampler. It is also
            passed as `num_samples` to `UniformTemporalSubsample`, so it
            doubles as the number of frames kept per clip.
        clips_per_video: number of clips the sampler is asked for.
        sample_rate: forwarded to `EncodedVideoDecord` only; it is not used
            when `video_file` is a string.
        with_audio: whether to decode audio and return it as well.
    Returns:
        The stacked crops tensor when `with_audio` is false; otherwise a
        tuple `(crops, audio)` where `audio` is the `"audio"` entry of the
        last clip that was read.
    Raises:
        ValueError: if the decoder returns no clip for a sampled time span.
    """
    short_side_scale = pv_transforms.ShortSideScale(224)
    normalize = NormalizeVideo(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    )
    video_transform = transforms.Compose([short_side_scale, normalize])

    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)

    if isinstance(video_file, str):
        # sample_rate is intentionally not forwarded on this branch.
        video = EncodedVideo.from_path(
            video_file,
            decoder="decord",
            decode_audio=with_audio,
        )
    else:
        video = EncodedVideoDecord(
            video_file,
            video_name=video_path,
            decode_video=True,
            decode_audio=with_audio,
            sample_rate=sample_rate,
        )

    sampled_frames = []
    for start_sec, end_sec in get_clip_timepoints(clip_sampler, video.duration):
        # Read the clip, get frames
        clip = video.get_clip(start_sec, end_sec)
        if clip is None:
            raise ValueError("No clip found")
        # Frames are scaled to 0-1 because the normalization works on floats.
        sampled_frames.append(frame_sampler(clip["video"]) / 255.0)

    transformed = [video_transform(frames) for frames in sampled_frames]
    spatial_crops = SpatialCrop(224, num_crops=3)(transformed)
    all_video = torch.stack(spatial_crops, dim=0)

    if with_audio:
        # `clip` is the last clip produced by the loop above.
        return all_video, clip['audio']
    return all_video
if __name__ == '__main__':
    # Manual smoke test: load one hard-coded sample video (with audio) and
    # stop in the debugger so the results can be inspected interactively.
    video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
    # The same path is passed as both `video_file` and `video_path`; since it
    # is a string, the file is opened via the path-based decoder branch.
    video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
    # NOTE(review): leftover breakpoint - this blocks any non-interactive run.
    import pdb;pdb.set_trace()