Spaces:
Running
Running
SauravMaheshkar
commited on
Commit
•
30d5854
1
Parent(s):
1d4cc3a
feat: use resize transform
Browse files- src/augmentations.py +12 -0
src/augmentations.py
CHANGED
@@ -7,6 +7,17 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
7 |
from torchvision import transforms
|
8 |
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
class GroupNormalize:
|
11 |
def __init__(self, mean: List[float], std: List[float]) -> None:
|
12 |
self.mean = mean
|
@@ -109,6 +120,7 @@ class TubeMaskingGenerator:
|
|
109 |
def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
|
110 |
return transforms.Compose(
|
111 |
[
|
|
|
112 |
GroupCenterCrop(input_size),
|
113 |
Stack(roll=False),
|
114 |
ToTorchFormatTensor(div=True),
|
|
|
7 |
from torchvision import transforms
|
8 |
|
9 |
|
10 |
+
class GroupResize:
|
11 |
+
def __init__(self, size: int = 256) -> None:
|
12 |
+
self.transform = transforms.Resize(size)
|
13 |
+
|
14 |
+
def __call__(
|
15 |
+
self, img_tuple: Tuple[torch.Tensor, torch.Tensor]
|
16 |
+
) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
17 |
+
img_group, label = img_tuple
|
18 |
+
return [self.transform(img) for img in img_group], label
|
19 |
+
|
20 |
+
|
21 |
class GroupNormalize:
|
22 |
def __init__(self, mean: List[float], std: List[float]) -> None:
|
23 |
self.mean = mean
|
|
|
120 |
def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
|
121 |
return transforms.Compose(
|
122 |
[
|
123 |
+
GroupResize(size=384),
|
124 |
GroupCenterCrop(input_size),
|
125 |
Stack(roll=False),
|
126 |
ToTorchFormatTensor(div=True),
|