SauravMaheshkar commited on
Commit
30d5854
1 Parent(s): 1d4cc3a

feat: use resize transform

Browse files
Files changed (1) hide show
  1. src/augmentations.py +12 -0
src/augmentations.py CHANGED
@@ -7,6 +7,17 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
7
  from torchvision import transforms
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
10
  class GroupNormalize:
11
  def __init__(self, mean: List[float], std: List[float]) -> None:
12
  self.mean = mean
@@ -109,6 +120,7 @@ class TubeMaskingGenerator:
109
  def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
110
  return transforms.Compose(
111
  [
 
112
  GroupCenterCrop(input_size),
113
  Stack(roll=False),
114
  ToTorchFormatTensor(div=True),
 
7
  from torchvision import transforms
8
 
9
 
10
+ class GroupResize:
11
+ def __init__(self, size: int = 256) -> None:
12
+ self.transform = transforms.Resize(size)
13
+
14
+ def __call__(
15
+ self, img_tuple: Tuple[torch.Tensor, torch.Tensor]
16
+ ) -> Tuple[List[torch.Tensor], torch.Tensor]:
17
+ img_group, label = img_tuple
18
+ return [self.transform(img) for img in img_group], label
19
+
20
+
21
  class GroupNormalize:
22
  def __init__(self, mean: List[float], std: List[float]) -> None:
23
  self.mean = mean
 
120
  def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
121
  return transforms.Compose(
122
  [
123
+ GroupResize(size=384),
124
  GroupCenterCrop(input_size),
125
  Stack(roll=False),
126
  ToTorchFormatTensor(div=True),