# videomae-vis / src / augmentations.py
# Author: SauravMaheshkar — "chore: refactor src" (commit 1d4cc3a)
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
class GroupNormalize:
    """Normalize a stacked clip tensor in place, slice by slice along dim 0.

    The ``mean``/``std`` sequences are tiled to cover every dim-0 slice, so a
    tensor holding several frames' channels (e.g. 3 * T channels) reuses the
    same per-channel statistics for every frame. The input tensor is mutated
    and returned together with its label.
    """

    def __init__(self, mean: List[float], std: List[float]) -> None:
        self.mean = mean
        self.std = std

    def __call__(
        self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tensor, label = tensor_tuple
        # Tile the statistics so zip() covers the whole dim-0 extent.
        tiled_mean = list(self.mean) * (tensor.size(0) // len(self.mean))
        tiled_std = list(self.std) * (tensor.size(0) // len(self.std))
        for channel, m, s in zip(tensor, tiled_mean, tiled_std):
            # In-place: (channel - m) / s, mutating the caller's tensor.
            channel.sub_(m).div_(s)
        return tensor, label
class GroupCenterCrop:
    """Apply the same torchvision ``CenterCrop`` to every frame of a clip.

    Takes a ``(frames, label)`` tuple and returns the cropped frames as a
    list alongside the untouched label.
    """

    def __init__(self, size: int) -> None:
        # One shared CenterCrop instance handles every frame.
        self.worker = transforms.CenterCrop(size)

    def __call__(
        self, img_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        frames, label = img_tuple
        cropped = []
        for frame in frames:
            cropped.append(self.worker(frame))
        return cropped, label
class Stack:
    """Stack a clip's frames into one array along the channel axis.

    Input is a ``(img_group, label)`` tuple where ``img_group`` is a list of
    frames exposing a PIL-style ``.mode`` attribute ("L" or "RGB"); frames
    are concatenated on axis 2, producing an H x W x (C * T) array.

    Args:
        roll: when True, reverse each RGB frame's channel order
            (RGB -> BGR) before stacking.

    Raises:
        ValueError: if the first frame's mode is neither "L" nor "RGB".
    """

    def __init__(self, roll: bool = False) -> None:
        self.roll = roll

    def __call__(self, img_tuple: Tuple[List, torch.Tensor]):
        img_group, label = img_tuple
        if img_group[0].mode == "L":
            # Grayscale: give each frame an explicit channel axis first.
            return (
                np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2),
                label,
            )
        elif img_group[0].mode == "RGB":
            if self.roll:
                # Reverse channels per frame (RGB -> BGR) before stacking.
                return (
                    np.concatenate(
                        [np.array(x)[:, :, ::-1] for x in img_group], axis=2
                    ),
                    label,
                )
            else:
                return np.concatenate(img_group, axis=2), label
        # BUG FIX: other modes previously fell off the end and silently
        # returned None; fail loudly with a diagnosable error instead.
        raise ValueError(f"Unsupported image mode: {img_group[0].mode!r}")
class ToTorchFormatTensor:
    """Convert a stacked H x W x C byte image to a C x H x W float tensor.

    Accepts either a ``np.ndarray`` or a ``PIL.Image.Image`` (paired with a
    label) and returns ``(tensor, label)``.

    Args:
        div: when True, scale byte values from [0, 255] into [0.0, 1.0].

    Raises:
        TypeError: if the input is neither an ndarray nor a PIL Image.
    """

    def __init__(self, div: bool = True) -> None:
        self.div = div

    def __call__(
        self, pic_tuple: Tuple[Union[np.ndarray, "Image.Image"], torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pic, label = pic_tuple
        if isinstance(pic, np.ndarray):
            # handle numpy array: HWC -> CHW
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        elif isinstance(pic, Image.Image):
            # handle PIL Image via its raw byte buffer
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        else:
            # BUG FIX: the message previously claimed torch.Tensor was an
            # accepted type; the supported inputs are ndarray and PIL Image.
            raise TypeError(
                f"Unsupported type {type(pic)} must be np.ndarray or PIL.Image.Image"
            )
        return img.float().div(255.0) if self.div else img.float(), label
class TubeMaskingGenerator:
    """Generate a VideoMAE-style "tube" mask over patch positions.

    A single random 0/1 mask is drawn for one frame's patch grid and then
    repeated for every frame, so the same spatial patches are masked across
    time. ``__call__`` returns a flat array of length frames * H * W where
    1 marks a masked patch.
    """

    def __init__(self, input_size: Tuple[int, int, int], mask_ratio: float) -> None:
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __call__(self):
        visible = self.num_patches_per_frame - self.num_masks_per_frame
        # Zeros (kept) followed by ones (masked), then shuffled in place.
        single_frame = np.concatenate(
            (np.zeros(visible), np.ones(self.num_masks_per_frame))
        )
        np.random.shuffle(single_frame)
        # Repeat the same frame mask across time and flatten.
        return np.tile(single_frame, (self.frames, 1)).flatten()
def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
    """Build the VideoMAE preprocessing pipeline.

    Center-crops each frame to ``input_size``, stacks frames along the
    channel axis, converts to a [0, 1] CHW float tensor, and normalizes
    with the standard ImageNet mean/std.
    """
    pipeline = [
        GroupCenterCrop(input_size),
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    return transforms.Compose(pipeline)