videomae-vis / src /augmentations.py
SauravMaheshkar's picture
feat: use resize transform
30d5854 unverified
raw
history blame
4.43 kB
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
class GroupResize:
    """Resize every frame of a clip with one shared torchvision ``Resize``.

    Operates on a ``(frames, label)`` tuple so it composes with the other
    Group* transforms in this module.
    """

    def __init__(self, size: int = 256) -> None:
        # Shorter-side resize; a single transform instance is reused per frame.
        self.transform = transforms.Resize(size)

    def __call__(
        self, img_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        frames, label = img_tuple
        resized = list(map(self.transform, frames))
        return resized, label
class GroupNormalize:
    """In-place per-channel normalization of a stacked clip tensor.

    ``mean``/``std`` describe one frame's channels; they are tiled to cover
    the tensor's first dimension (C*T after ``Stack`` + ``ToTorchFormatTensor``),
    then each channel slice is normalized as ``(x - mean) / std`` in place.
    """

    def __init__(self, mean: List[float], std: List[float]) -> None:
        self.mean = mean
        self.std = std

    def __call__(
        self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tensor, label = tensor_tuple
        # Tile per-frame stats across all frames stacked along dim 0.
        repeats = tensor.size(0) // len(self.mean)
        channel_means = self.mean * repeats
        channel_stds = self.std * (tensor.size(0) // len(self.std))
        for channel, m, s in zip(tensor, channel_means, channel_stds):
            channel.sub_(m).div_(s)  # mutates `tensor` in place
        return tensor, label
class GroupCenterCrop:
    """Center-crop every frame of a clip to a square of side ``size``."""

    def __init__(self, size: int) -> None:
        # One CenterCrop instance shared by all frames.
        self.worker = transforms.CenterCrop(size)

    def __call__(
        self, img_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        frames, label = img_tuple
        cropped = [self.worker(frame) for frame in frames]
        return cropped, label
class Stack:
    """Concatenate a group of frames along the channel axis (HWC layout).

    Grayscale ("L") frames gain a channel dimension before stacking; RGB
    frames are stacked directly, optionally with channels reversed
    (``roll=True`` gives BGR order for legacy caffe-style weights).

    Raises:
        ValueError: if the frames' image mode is neither "L" nor "RGB".
            (Previously this case fell through and returned ``None``
            implicitly, failing far from the cause.)
    """

    def __init__(self, roll: bool = False) -> None:
        self.roll = roll

    def __call__(self, img_tuple: Tuple[torch.Tensor, torch.Tensor]):
        img_group, label = img_tuple
        mode = img_group[0].mode
        if mode == "L":
            # Grayscale: (H, W) -> (H, W, 1) per frame, then stack channels.
            return (
                np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2),
                label,
            )
        if mode == "RGB":
            if self.roll:
                # Reverse channel order per frame (RGB -> BGR) before stacking.
                return (
                    np.concatenate(
                        [np.array(x)[:, :, ::-1] for x in img_group], axis=2
                    ),
                    label,
                )
            return np.concatenate(img_group, axis=2), label
        # Fail loudly instead of silently returning None for other modes.
        raise ValueError(f"Unsupported image mode {mode!r}; expected 'L' or 'RGB'")
class ToTorchFormatTensor:
    """Convert an HWC ``np.ndarray`` or PIL Image to a CHW float ``torch.Tensor``.

    With ``div=True`` (default), byte values in [0, 255] are rescaled to
    [0.0, 1.0].

    Raises:
        TypeError: if the input is neither an ``np.ndarray`` nor a PIL Image.
    """

    def __init__(self, div: bool = True) -> None:
        self.div = div

    def __call__(
        self, pic_tuple: Tuple[Union[np.ndarray, "Image.Image"], torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pic, label = pic_tuple
        if isinstance(pic, np.ndarray):
            # numpy array: HWC -> CHW
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        elif isinstance(pic, Image.Image):
            # PIL Image: decode raw bytes, reshape to HWC, then move to CHW.
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        else:
            # BUG FIX: the old message claimed torch.Tensor was accepted,
            # but only np.ndarray and PIL Image are actually handled above.
            raise TypeError(
                f"Unsupported type {type(pic)} must be np.ndarray or PIL Image"
            )
        return img.float().div(255.0) if self.div else img.float(), label
class TubeMaskingGenerator:
    """Generate a "tube" mask: one random per-frame patch mask repeated
    identically across every frame of the clip.

    ``input_size`` is (frames, height, width) in patch units; ``mask_ratio``
    is the fraction of patches masked per frame (1 = masked, 0 = kept).
    """

    def __init__(self, input_size: Tuple[int, int, int], mask_ratio: float) -> None:
        frames, height, width = input_size
        self.frames = frames
        self.height = height
        self.width = width
        self.num_patches_per_frame = height * width
        self.total_patches = frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = frames * self.num_masks_per_frame

    def __call__(self):
        keep = self.num_patches_per_frame - self.num_masks_per_frame
        frame_mask = np.hstack([np.zeros(keep), np.ones(self.num_masks_per_frame)])
        np.random.shuffle(frame_mask)
        # Repeating the same shuffled mask per frame yields tubes across time.
        return np.tile(frame_mask, (self.frames, 1)).flatten()
def get_videomae_transform(
    input_size: int = 224, resize_size: int = 384
) -> "transforms.Compose":
    """Build the VideoMAE preprocessing pipeline.

    Args:
        input_size: side length of the final square center crop fed to
            the model.
        resize_size: shorter-side resize applied before the crop
            (previously hard-coded to 384; default preserves old behavior).

    Returns:
        A torchvision ``Compose``: resize -> center crop -> stack frames
        along channels -> CHW float tensor in [0, 1] -> ImageNet-stat
        normalization.
    """
    return transforms.Compose(
        [
            GroupResize(size=resize_size),
            GroupCenterCrop(input_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            GroupNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )