from typing import List, Tuple, Union
import numpy as np
import torch
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms


class GroupResize:
    # Resize every frame in the clip; the label passes through unchanged.
    def __init__(self, size: int = 256) -> None:
        self.transform = transforms.Resize(size)

    def __call__(
        self, img_tuple: Tuple[List[Image.Image], torch.Tensor]
    ) -> Tuple[List[Image.Image], torch.Tensor]:
        img_group, label = img_tuple
        return [self.transform(img) for img in img_group], label


class GroupNormalize:
    # Normalize a stacked clip tensor channel by channel, in place.
    def __init__(self, mean: List[float], std: List[float]) -> None:
        self.mean = mean
        self.std = std

    def __call__(
        self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tensor, label = tensor_tuple
        # The clip tensor is (num_frames * channels, H, W), so repeat the
        # per-channel statistics once per frame.
        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
        rep_std = self.std * (tensor.size()[0] // len(self.std))
        for t, m, s in zip(tensor, rep_mean, rep_std):
            t.sub_(m).div_(s)
        return tensor, label


class GroupCenterCrop:
    # Center-crop every frame in the clip; the label passes through unchanged.
    def __init__(self, size: int) -> None:
        self.worker = transforms.CenterCrop(size)

    def __call__(
        self, img_tuple: Tuple[List[Image.Image], torch.Tensor]
    ) -> Tuple[List[Image.Image], torch.Tensor]:
        img_group, label = img_tuple
        return [self.worker(img) for img in img_group], label


class Stack:
    # Stack the frames of a clip into a single numpy array along the channel axis.
    def __init__(self, roll: bool = False) -> None:
        self.roll = roll

    def __call__(
        self, img_tuple: Tuple[List[Image.Image], torch.Tensor]
    ) -> Tuple[np.ndarray, torch.Tensor]:
        img_group, label = img_tuple
        if img_group[0].mode == "L":
            return (
                np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2),
                label,
            )
        elif img_group[0].mode == "RGB":
            if self.roll:
                # Reverse the channel order (RGB -> BGR) before stacking.
                return (
                    np.concatenate(
                        [np.array(x)[:, :, ::-1] for x in img_group], axis=2
                    ),
                    label,
                )
            return np.concatenate(img_group, axis=2), label
        raise ValueError(
            f"Unsupported image mode {img_group[0].mode}; expected 'L' or 'RGB'"
        )


class ToTorchFormatTensor:
    # Convert a stacked (H, W, C) numpy array or a PIL image to a CHW tensor,
    # scaled to [0, 1] when div=True.
    def __init__(self, div: bool = True) -> None:
        self.div = div

    def __call__(
        self, pic_tuple: Tuple[Union[np.ndarray, Image.Image], torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pic, label = pic_tuple
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        elif isinstance(pic, Image.Image):
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        else:
            raise TypeError(
                f"Unsupported type {type(pic)}: must be np.ndarray or PIL.Image"
            )
        return img.float().div(255.0) if self.div else img.float(), label


class TubeMaskingGenerator:
    # Generates a random "tube" mask: the same per-frame patch mask repeated
    # across all frames of the clip.
    def __init__(self, input_size: Tuple[int, int, int], mask_ratio: float) -> None:
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __call__(self) -> np.ndarray:
        # 1 marks a masked patch, 0 a visible one.
        mask_per_frame = np.hstack(
            [
                np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
                np.ones(self.num_masks_per_frame),
            ]
        )
        np.random.shuffle(mask_per_frame)
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask
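

# Usage sketch (illustrative; the 8x14x14 token grid and 0.9 ratio below are
# assumed values, not defaults of this module):
#
#     generator = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
#     bool_mask = generator().astype(bool)  # shape (8 * 14 * 14,), True = masked
#
# Because the shuffled per-frame pattern is tiled across every frame, each masked
# patch forms a "tube" through time.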


def get_videomae_transform(input_size: int = 224) -> transforms.Compose:
    # Resize, center-crop, stack, convert, and normalize a (frames, label) tuple.
    return transforms.Compose(
        [
            GroupResize(size=384),
            GroupCenterCrop(input_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            GroupNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
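

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): build the transform and run it on a
    # dummy clip of 16 solid-color RGB frames. The 16-frame clip length and the
    # dummy label are assumptions for this example, not requirements of the module.
    transform = get_videomae_transform(input_size=224)
    frames = [Image.new("RGB", (320, 240)) for _ in range(16)]
    video_tensor, _ = transform((frames, torch.tensor(0)))
    # Channels of all frames are stacked along dim 0: (16 * 3, 224, 224).
    print(video_tensor.shape)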