from typing import List, Tuple, Union

import numpy as np
import torch
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms


class GroupNormalize:
    """Normalize a stacked (num_frames * C, H, W) tensor channel-wise, in place."""

    def __init__(self, mean: List[float], std: List[float]) -> None:
        self.mean = mean
        self.std = std

    def __call__(
        self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tensor, label = tensor_tuple
        # Repeat the per-channel statistics to cover every frame stacked along dim 0.
        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
        rep_std = self.std * (tensor.size()[0] // len(self.std))
        for t, m, s in zip(tensor, rep_mean, rep_std):
            t.sub_(m).div_(s)
        return tensor, label


class GroupCenterCrop:
    """Apply the same center crop to every image in a frame group."""

    def __init__(self, size: int) -> None:
        self.worker = transforms.CenterCrop(size)

    def __call__(
        self, img_tuple: Tuple[List[Image.Image], torch.Tensor]
    ) -> Tuple[List[Image.Image], torch.Tensor]:
        img_group, label = img_tuple
        return [self.worker(img) for img in img_group], label


class Stack:
    """Concatenate a group of PIL images along the channel axis into one HWC array."""

    def __init__(self, roll: bool = False) -> None:
        self.roll = roll

    def __call__(
        self, img_tuple: Tuple[List[Image.Image], torch.Tensor]
    ) -> Tuple[np.ndarray, torch.Tensor]:
        img_group, label = img_tuple
        if img_group[0].mode == "L":
            return (
                np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2),
                label,
            )
        elif img_group[0].mode == "RGB":
            if self.roll:
                # Reverse the channel order (RGB -> BGR) before stacking.
                return (
                    np.concatenate(
                        [np.array(x)[:, :, ::-1] for x in img_group], axis=2
                    ),
                    label,
                )
            return np.concatenate(img_group, axis=2), label
        raise ValueError(
            f"Unsupported image mode {img_group[0].mode!r}, expected 'L' or 'RGB'"
        )


class ToTorchFormatTensor:
    """Convert an HWC numpy array or PIL image to a CHW float tensor, optionally scaled to [0, 1]."""

    def __init__(self, div: bool = True) -> None:
        self.div = div

    def __call__(
        self, pic_tuple: Tuple[Union[np.ndarray, Image.Image], torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pic, label = pic_tuple
        if isinstance(pic, np.ndarray):
            # Handle numpy array: HWC -> CHW.
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        elif isinstance(pic, Image.Image):
            # Handle PIL Image.
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # Put it from HWC to CHW format;
            # this transpose dominates the loading time/CPU.
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        else:
            raise TypeError(
                f"Unsupported type {type(pic)}, must be np.ndarray or PIL.Image"
            )
        return img.float().div(255.0) if self.div else img.float(), label


class TubeMaskingGenerator:
    """Generate a tube mask: the same per-frame patch mask is repeated for every frame."""

    def __init__(self, input_size: Tuple[int, int, int], mask_ratio: float) -> None:
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __call__(self) -> np.ndarray:
        mask_per_frame = np.hstack(
            [
                np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
                np.ones(self.num_masks_per_frame),
            ]
        )
        np.random.shuffle(mask_per_frame)
        # Repeat the shuffled per-frame mask across all frames (tube masking) and flatten.
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask


def get_videomae_transform(input_size: int = 224) -> transforms.Compose:
    """Build the VideoMAE preprocessing pipeline for a (frame list, label) pair."""
    return transforms.Compose(
        [
            GroupCenterCrop(input_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            GroupNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
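

# --- Minimal usage sketch (not part of the original module) ---
# Assumptions: a 16-frame RGB clip, a 224px crop, and an 8 x 14 x 14 token grid
# (2-frame temporal tubes, 16px spatial patches); adjust these to the actual model config.
if __name__ == "__main__":
    transform = get_videomae_transform(input_size=224)
    frames = [Image.new("RGB", (320, 256)) for _ in range(16)]
    video, label = transform((frames, 0))
    # The transform stacks every frame's channels along dim 0: (16 * 3, 224, 224).
    print(video.shape, label)

    masker = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
    mask = masker()  # flat 0/1 array of length 8 * 14 * 14
    print(mask.shape, int(mask.sum()))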