Spaces:
Paused
Paused
File size: 6,660 Bytes
c6f92cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Sequence
import torch
import torch.nn as nn
from torchvision import transforms
class Permute(nn.Module):
    """
    Dimension permutation wrapped as a module, so it can be placed inside a
    transform pipeline.

    Args:
        ordering: the dimension order handed to ``permute`` on every call.
    """

    def __init__(self, ordering):
        super().__init__()
        self.ordering = ordering

    def forward(self, frames):
        """
        Args:
            frames: input whose dimensions are to be reordered
                (by default laid out as (C, T, H, W)).
        Returns:
            ``frames`` with its dimensions rearranged according to
            ``self.ordering``.
        """
        reordered = frames.permute(self.ordering)
        return reordered
class TemporalCrop(nn.Module):
    """
    Split a video into a list of shorter clips along the time dimension.

    Args:
        frames_per_clip: number of frames kept in each clip.
        stride: step, in frames of the input video, between the start
            positions of consecutive clips.
        frame_stride: step between the frames sampled inside one clip.
    """

    def __init__(
        self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
    ):
        super().__init__()
        self.frames = frames_per_clip
        self.stride = stride
        self.frame_stride = frame_stride

    def forward(self, video):
        """
        Args:
            video: 4-D tensor laid out as (C, T, H, W).
        Returns:
            A list of clips, each obtained by slicing dimension 1 of
            ``video``. The list is empty when the video has fewer than
            ``frames_per_clip * frame_stride`` frames.
        """
        assert video.ndim == 4, "Must be (C, T, H, W)"
        # Number of source frames one clip spans before sub-sampling.
        span = self.frames * self.frame_stride
        # A clip may only start where a full span still fits in the video.
        first_invalid_start = video.size(1) - span + 1
        return [
            video[:, begin: begin + span: self.frame_stride, ...]
            for begin in range(0, first_invalid_start, self.stride)
        ]
def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.

    Columns 0 and 2 of every box are shifted by ``-x_offset`` and columns
    1 and 3 by ``-y_offset`` (presumably an (x1, y1, x2, y2) layout --
    confirm against the caller). The input array is left untouched.

    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4, or None when ``boxes`` is None.
    """
    # The documented contract allows None; pass it through instead of
    # failing on `None.copy()`.
    if boxes is None:
        return None

    # Work on a copy so the caller's array is not modified in place.
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Take one of three evenly placed square crops from the images, shifting
    the optional boxes by the same offsets.

    The crop slides along the longer spatial side: left / center / right
    when the width is at least the height, top / center / bottom when the
    height is larger. The other axis is always center-cropped.

    Args:
        images (tensor): images to crop, `num frames` x `channel` x `height`
            x `width`. A 3-D input is treated as a single image and a 3-D
            result is returned for it.
        size (int): side length (height and width) of the square crop.
        spatial_idx (int): 0, 1, or 2, selecting the first, center, or last
            position along the longer side.
        boxes (ndarray or None): optional. Boxes corresponding to the images,
            `num boxes` x 4.
        scale_size (int): optional. If not None, the images are first resized
            (bilinear) so that their shorter side equals scale_size.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the shifted boxes, `num boxes` x 4,
            or None when no boxes were given.
    """
    assert spatial_idx in [0, 1, 2]

    # Temporarily add a leading dimension so a single image can share the
    # 4-D code path; it is removed again before returning.
    single_image = len(images.shape) == 3
    if single_image:
        images = images.unsqueeze(0)

    height, width = images.shape[2], images.shape[3]

    if scale_size is not None:
        # Resize so the shorter side becomes scale_size, keeping the aspect
        # ratio (the longer side is truncated to an int).
        if width <= height:
            height = int(height / width * scale_size)
            width = scale_size
        else:
            width = int(width / height * scale_size)
            height = scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    # Start from a centered crop on both axes.
    top = int(math.ceil((height - size) / 2))
    left = int(math.ceil((width - size) / 2))

    # Only the longer side gets moved to an edge for indices 0 and 2.
    slide_vertically = height > width
    long_side = height if slide_vertically else width
    if spatial_idx == 0:
        edge_offset = 0
    elif spatial_idx == 2:
        edge_offset = long_side - size
    else:
        edge_offset = None

    if edge_offset is not None:
        if slide_vertically:
            top = edge_offset
        else:
            left = edge_offset

    rows = slice(top, top + size)
    cols = slice(left, left + size)
    cropped = images[:, :, rows, cols]

    if boxes is None:
        cropped_boxes = None
    else:
        cropped_boxes = crop_boxes(boxes, left, top)

    if single_image:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
class SpatialCrop(nn.Module):
    """
    Turn each video into several spatially cropped copies. Intended to run
    after the temporal crops, on full frames (i.e. with -2 as the spatial
    crop at the slowfast augmentation stage). The output list holds
    ``num_crops`` crops per input video, which is useful for multi-view
    testing such as 3x4 (eg in SwinT) or 3x10 in SlowFast.

    Args:
        crop_size: side length of each square crop.
        num_crops: 1 (center crop only), 3 (three uniform crops) or
            6 (three uniform crops plus the same three crops of the
            horizontally flipped video).
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 1:
            # Center crop only.
            self.crops_to_ext = [1]
        elif num_crops in (3, 6):
            self.crops_to_ext = [0, 1, 2]
        else:
            raise NotImplementedError(
                "Nothing else supported yet, "
                "slowfast only takes 0, 1, 2 as arguments"
            )
        # I guess Swin uses 5 crops without flipping, but that doesn't
        # make sense given they first resize to 224 and take 224 crops.
        # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
        # So I'm assuming we can use flipped crops and that will add sth..
        self.flipped_crops_to_ext = [0, 1, 2] if num_crops == 6 else []

    def forward(self, videos: Sequence[torch.Tensor]):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with ``num_crops`` entries per input video, in
                input order. Each entry is the video spatially cropped to
                C, T, H', W'; for every video the plain crops come first,
                followed by the crops of its horizontal flip (if any).
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([clip.ndim == 4 for clip in videos]), "Must be (C,T,H,W)"
        crops = []
        for clip in videos:
            crops.extend(
                uniform_crop(clip, self.crop_size, idx)[0]
                for idx in self.crops_to_ext
            )
            if self.flipped_crops_to_ext:
                # Flip once per video and reuse it for every flipped crop.
                mirrored = transforms.functional.hflip(clip)
                crops.extend(
                    uniform_crop(mirrored, self.crop_size, idx)[0]
                    for idx in self.flipped_crops_to_ext
                )
        return crops
|