from typing import List, Tuple, Union

import numpy as np
import torch
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms


class GroupNormalize:
    """Normalize a stacked (T * C, H, W) clip channel-wise, in place."""

    def __init__(self, mean: List[float], std: List[float]) -> None:
        self.mean = mean
        self.std = std

    def __call__(
        self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tensor, label = tensor_tuple
        # Repeat the per-channel statistics across every stacked frame.
        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
        rep_std = self.std * (tensor.size()[0] // len(self.std))
        for t, m, s in zip(tensor, rep_mean, rep_std):
            t.sub_(m).div_(s)
        return tensor, label


class GroupCenterCrop:
    """Apply the same center crop to every PIL frame in a group."""

    def __init__(self, size: int) -> None:
        self.worker = transforms.CenterCrop(size)

    def __call__(
        self, img_tuple: Tuple[List[Image.Image], torch.Tensor]
    ) -> Tuple[List[Image.Image], torch.Tensor]:
        img_group, label = img_tuple
        return [self.worker(img) for img in img_group], label


class Stack:
    """Stack a list of PIL frames into a single (H, W, T * C) numpy array."""

    def __init__(self, roll: bool = False) -> None:
        self.roll = roll

    def __call__(
        self, img_tuple: Tuple[List[Image.Image], torch.Tensor]
    ) -> Tuple[np.ndarray, torch.Tensor]:
        img_group, label = img_tuple
        if img_group[0].mode == "L":
            # Grayscale frames: add a channel axis before concatenating.
            return (
                np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2),
                label,
            )
        if img_group[0].mode == "RGB":
            if self.roll:
                # Reverse the channel order (RGB -> BGR) frame by frame.
                return (
                    np.concatenate(
                        [np.array(x)[:, :, ::-1] for x in img_group], axis=2
                    ),
                    label,
                )
            return np.concatenate(img_group, axis=2), label
        raise ValueError(f"Unsupported image mode {img_group[0].mode!r}")


class ToTorchFormatTensor:
    """Convert an (H, W, C) numpy array or PIL Image to a (C, H, W) float tensor."""

    def __init__(self, div: bool = True) -> None:
        self.div = div

    def __call__(
        self, pic_tuple: Tuple[Union[np.ndarray, Image.Image], torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pic, label = pic_tuple
        if isinstance(pic, np.ndarray):
            # Handle numpy array: HWC -> CHW.
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        elif isinstance(pic, Image.Image):
            # Handle PIL Image.
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # Put it from HWC to CHW format; this double transpose
            # dominates the loading time on CPU.
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        else:
            raise TypeError(
                f"Unsupported type {type(pic)}: must be np.ndarray or PIL.Image.Image"
            )
        return (img.float().div(255.0) if self.div else img.float()), label


class TubeMaskingGenerator:
    """Generate a tube mask: one random per-frame pattern repeated across time."""

    def __init__(self, input_size: Tuple[int, int, int], mask_ratio: float) -> None:
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __call__(self) -> np.ndarray:
        # Build one frame's mask (1 = masked, 0 = visible) and shuffle it ...
        mask_per_frame = np.hstack(
            [
                np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
                np.ones(self.num_masks_per_frame),
            ]
        )
        np.random.shuffle(mask_per_frame)
        # ... then repeat the same pattern for every frame (the "tube").
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask
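

# Worked example (sizes assumed for illustration, not taken from any config):
# with input_size=(8, 14, 14) and mask_ratio=0.9, each frame masks
# int(0.9 * 196) = 176 of its 196 patches, so the flat mask has
# 8 * 196 = 1568 entries of which 8 * 176 = 1408 are ones, and every
# temporal step repeats the same shuffled pattern.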


def get_videomae_transform(input_size: int = 224) -> transforms.Compose:
    """Compose the VideoMAE eval pipeline: crop, stack, to-tensor, normalize."""
    return transforms.Compose(
        [
            GroupCenterCrop(input_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            GroupNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
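

if __name__ == "__main__":
    # Minimal smoke test, as a sketch: the 224 crop size and the 8x14x14
    # token grid below are assumed values (matching a ViT-B/16 VideoMAE
    # with tubelet size 2 over 16 frames), not values read from a checkpoint.
    frames = [Image.new("RGB", (320, 256)) for _ in range(16)]
    transform = get_videomae_transform(input_size=224)
    video, _ = transform((frames, None))
    print(video.shape)  # torch.Size([48, 224, 224]): 16 frames * 3 channels

    generator = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
    mask = generator()
    print(mask.shape, int(mask.sum()))  # (1568,) 1408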