SauravMaheshkar committed on
Commit
65947b1
•
1 Parent(s): 38cb6ed

feat: initial commit

Browse files
Files changed (11)
  1. .gitattributes +1 -0
  2. .pre-commit-config.yaml +10 -0
  3. README.md +22 -5
  4. app.py +53 -0
  5. assets/checkpoint.pth +3 -0
  6. assets/example.mp4 +3 -0
  7. augmentations.py +117 -0
  8. models.py +714 -0
  9. pyproject.toml +2 -0
  10. requirements.txt +6 -0
  11. utils.py +144 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
.pre-commit-config.yaml ADDED
@@ -0,0 +1,10 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 24.4.2
+    hooks:
+      - id: black
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
README.md CHANGED
@@ -1,13 +1,30 @@
 ---
-title: Videomae Vis
+title: VideoMAE Visualisation
 emoji: 🌍
-colorFrom: indigo
-colorTo: red
+colorFrom: blue
+colorTo: blue
 sdk: gradio
 sdk_version: 4.39.0
 app_file: app.py
-pinned: false
+pinned: true
 license: cc-by-4.0
+short_description: Visualise outputs of VideoMAE
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+## References
+
+Source Paper 📜: https://arxiv.org/abs/2203.12602
+
+<details>
+<summary>Citation</summary>
+
+```lang-misc
+@inproceedings{tong2022videomae,
+  title={Video{MAE}: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training},
+  author={Zhan Tong and Yibing Song and Jue Wang and Limin Wang},
+  booktitle={Advances in Neural Information Processing Systems},
+  year={2022}
+}
+```
+
+</details>
app.py ADDED
@@ -0,0 +1,53 @@
+import gradio as gr
+import torch
+
+from augmentations import get_videomae_transform
+from models import load_model
+from utils import create_plot, get_frames, get_videomae_outputs, prepare_frames_masks
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+transform = get_videomae_transform()
+
+
+def get_visualisations(mask_ratio, video_path):
+    frames, ids = get_frames(path=video_path, transform=transform)
+
+    model, masks, patch_size = load_model(
+        path="assets/checkpoint.pth",
+        mask_ratio=mask_ratio,
+        device=device,
+    )
+
+    with torch.no_grad():
+        frames, masks = prepare_frames_masks(frames, masks, device)
+        outputs = model(frames, masks)
+
+    visualisations = get_videomae_outputs(
+        frames=frames,
+        masks=masks,
+        outputs=outputs,
+        ids=ids,
+        patch_size=patch_size,
+        device=device,
+    )
+
+    return create_plot(visualisations)
+
+
+with gr.Blocks() as app:
+    video = gr.Video(
+        value="assets/example.mp4",
+    )
+    mask_ratio_slider = gr.Slider(
+        minimum=0.25, maximum=0.95, step=0.05, value=0.75, label="masking ratio"
+    )
+    btn = gr.Button("Run")
+    btn.click(
+        get_visualisations,
+        inputs=[mask_ratio_slider, video],
+        outputs=gr.Plot(label="VideoMAE Outputs", format="png"),
+    )
+
+if __name__ == "__main__":
+    app.launch()
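
The Gradio UI above simply forwards the slider value and the video path into `get_visualisations`. A minimal smoke-test sketch of that same call path, assuming the Space has been cloned with its LFS assets pulled (`assets/checkpoint.pth`, `assets/example.mp4`); the output filename is illustrative:

```python
# Call the same function the "Run" button is wired to, outside of Gradio.
# Assumes the LFS assets exist locally after `git lfs pull`.
from app import get_visualisations

fig = get_visualisations(mask_ratio=0.75, video_path="assets/example.mp4")

# create_plot() returns a matplotlib Figure, so it can be saved directly.
fig.savefig("videomae_outputs.png")
```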
assets/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:444df1bafe93eb03915a1cc97b23abf6ce843cb41555ae25795fbb5aefe5957e
+size 376929369
assets/example.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a1dd3bf1d9468a2ec14a7cf7523f8ee1de369da6a91276512e771ca7835e453
+size 116347302
augmentations.py ADDED
@@ -0,0 +1,117 @@
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision import transforms
+
+
+class GroupNormalize:
+    def __init__(self, mean: List[float], std: List[float]) -> None:
+        self.mean = mean
+        self.std = std
+
+    def __call__(
+        self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        tensor, label = tensor_tuple
+        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
+        rep_std = self.std * (tensor.size()[0] // len(self.std))
+
+        for t, m, s in zip(tensor, rep_mean, rep_std):
+            t.sub_(m).div_(s)
+
+        return tensor, label
+
+
+class GroupCenterCrop:
+    def __init__(self, size: int) -> None:
+        self.worker = transforms.CenterCrop(size)
+
+    def __call__(
+        self, img_tuple: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
+        img_group, label = img_tuple
+        return [self.worker(img) for img in img_group], label
+
+
+class Stack:
+    def __init__(self, roll: Optional[bool] = False) -> None:
+        self.roll = roll
+
+    def __call__(self, img_tuple: Tuple[torch.Tensor, torch.Tensor]):
+        img_group, label = img_tuple
+
+        if img_group[0].mode == "L":
+            return (
+                np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2),
+                label,
+            )
+        elif img_group[0].mode == "RGB":
+            if self.roll:
+                return (
+                    np.concatenate(
+                        [np.array(x)[:, :, ::-1] for x in img_group], axis=2
+                    ),
+                    label,
+                )
+            else:
+                return np.concatenate(img_group, axis=2), label
+
+
+class ToTorchFormatTensor:
+    def __init__(self, div: Optional[bool] = True) -> None:
+        self.div = div
+
+    def __call__(
+        self, pic_tuple: Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        pic, label = pic_tuple
+
+        if isinstance(pic, np.ndarray):
+            # handle numpy array
+            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
+        elif isinstance(pic, Image.Image):
+            # handle PIL Image
+            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
+            # put it from HWC to CHW format
+            # yikes, this transpose takes 80% of the loading time/CPU
+            img = img.transpose(0, 1).transpose(0, 2).contiguous()
+        else:
+            raise TypeError(
+                f"Unsupported type {type(pic)} must be np.ndarray or torch.Tensor"
+            )
+        return img.float().div(255.0) if self.div else img.float(), label
+
+
+class TubeMaskingGenerator:
+    def __init__(self, input_size: Tuple[int, int, int], mask_ratio: float) -> None:
+        self.frames, self.height, self.width = input_size
+        self.num_patches_per_frame = self.height * self.width
+        self.total_patches = self.frames * self.num_patches_per_frame
+        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
+        self.total_masks = self.frames * self.num_masks_per_frame
+
+    def __call__(self):
+        mask_per_frame = np.hstack(
+            [
+                np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
+                np.ones(self.num_masks_per_frame),
+            ]
+        )
+        np.random.shuffle(mask_per_frame)
+        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
+        return mask
+
+
+def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
+    return transforms.Compose(
+        [
+            GroupCenterCrop(input_size),
+            Stack(roll=False),
+            ToTorchFormatTensor(div=True),
+            GroupNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+        ]
+    )
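
A rough illustration of `TubeMaskingGenerator` with the sizes used elsewhere in this commit (16 frames with tubelet size 2, a 14×14 patch grid, 75% masking); the numbers in the comments follow directly from the class definition above:

```python
from augmentations import TubeMaskingGenerator

# 16 frames / tubelet_size 2 -> 8 temporal positions; 224 / 16 -> 14 patches per side.
gen = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.75)

mask = gen()  # flat 0/1 array; one spatial pattern repeated for every temporal tube
print(mask.shape)       # (1568,) = 8 * 14 * 14
print(int(mask.sum()))  # 1176   = 8 * int(0.75 * 196) masked positions
```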
models.py ADDED
@@ -0,0 +1,714 @@
+from functools import partial
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+from augmentations import TubeMaskingGenerator
+
+__all__ = ["load_model"]
+
+
+def _cfg(url="", **kwargs):
+    return {
+        "url": url,
+        "num_classes": 400,
+        "input_size": (3, 224, 224),
+        "pool_size": None,
+        "crop_pct": 0.9,
+        "interpolation": "bicubic",
+        "mean": (0.5, 0.5, 0.5),
+        "std": (0.5, 0.5, 0.5),
+        **kwargs,
+    }
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        drop=0.0,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # commit this for the original BERT implementation
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=False,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+        attn_head_dim=None,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat(
+                (
+                    self.q_bias,
+                    torch.zeros_like(self.v_bias, requires_grad=False),
+                    self.v_bias,
+                )
+            )
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        init_values=None,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        attn_head_dim=None,
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            attn_head_dim=attn_head_dim,
+        )
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
+
+        if init_values > 0:
+            self.gamma_1 = nn.Parameter(
+                init_values * torch.ones((dim)), requires_grad=True
+            )
+            self.gamma_2 = nn.Parameter(
+                init_values * torch.ones((dim)), requires_grad=True
+            )
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(self.attn(self.norm1(x)))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding"""
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        num_frames=16,
+        tubelet_size=2,
+    ):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        self.tubelet_size = int(tubelet_size)
+        num_patches = (
+            (img_size[1] // patch_size[1])
+            * (img_size[0] // patch_size[0])
+            * (num_frames // self.tubelet_size)
+        )
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+        self.proj = nn.Conv3d(
+            in_channels=in_chans,
+            out_channels=embed_dim,
+            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
+            stride=(self.tubelet_size, patch_size[0], patch_size[1]),
+        )
+
+    def forward(self, x, **kwargs):
+        B, C, T, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert (
+            H == self.img_size[0] and W == self.img_size[1]
+        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+def get_sinusoid_encoding_table(n_position, d_hid):
+    def get_position_angle_vec(position):
+        return [
+            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+            for hid_j in range(d_hid)
+        ]
+
+    sinusoid_table = np.array(
+        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
+    )
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    return torch.tensor(
+        sinusoid_table, dtype=torch.float, requires_grad=False
+    ).unsqueeze(0)
+
+
+class PretrainVisionTransformerEncoder(nn.Module):
+    """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        num_classes=0,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.0,
+        norm_layer=nn.LayerNorm,
+        init_values=None,
+        tubelet_size=2,
+        use_checkpoint=False,
+        use_learnable_pos_emb=False,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = (
+            embed_dim  # num_features for consistency with other models
+        )
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            tubelet_size=tubelet_size,
+        )
+        num_patches = self.patch_embed.num_patches
+        self.use_checkpoint = use_checkpoint
+
+        # TODO: Add the cls token
+        if use_learnable_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            # sine-cosine positional embeddings
+            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, depth)
+        ]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop_rate,
+                    attn_drop=attn_drop_rate,
+                    drop_path=dpr[i],
+                    norm_layer=norm_layer,
+                    init_values=init_values,
+                )
+                for i in range(depth)
+            ]
+        )
+        self.norm = norm_layer(embed_dim)
+        self.head = (
+            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        )
+
+        if use_learnable_pos_emb:
+            trunc_normal_(self.pos_embed, std=0.02)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {"pos_embed", "cls_token"}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=""):
+        self.num_classes = num_classes
+        self.head = (
+            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        )
+
+    def forward_features(self, x, mask):
+        _, _, T, _, _ = x.shape
+        x = self.patch_embed(x)
+
+        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
+
+        B, _, C = x.shape
+        x_vis = x[~mask].reshape(B, -1, C)  # ~mask means visible
+
+        if self.use_checkpoint:
+            for blk in self.blocks:
+                x_vis = checkpoint.checkpoint(blk, x_vis)
+        else:
+            for blk in self.blocks:
+                x_vis = blk(x_vis)
+
+        x_vis = self.norm(x_vis)
+        return x_vis
+
+    def forward(self, x, mask):
+        x = self.forward_features(x, mask)
+        x = self.head(x)
+        return x
+
+
+class PretrainVisionTransformerDecoder(nn.Module):
+    """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+    def __init__(
+        self,
+        patch_size=16,
+        num_classes=768,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.0,
+        norm_layer=nn.LayerNorm,
+        init_values=None,
+        num_patches=196,
+        tubelet_size=2,
+        use_checkpoint=False,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        assert num_classes == 3 * tubelet_size * patch_size**2
+        self.num_features = self.embed_dim = (
+            embed_dim  # num_features for consistency with other models
+        )
+        self.patch_size = patch_size
+        self.use_checkpoint = use_checkpoint
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, depth)
+        ]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop_rate,
+                    attn_drop=attn_drop_rate,
+                    drop_path=dpr[i],
+                    norm_layer=norm_layer,
+                    init_values=init_values,
+                )
+                for i in range(depth)
+            ]
+        )
+        self.norm = norm_layer(embed_dim)
+        self.head = (
+            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {"pos_embed", "cls_token"}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=""):
+        self.num_classes = num_classes
+        self.head = (
+            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        )
+
+    def forward(self, x, return_token_num):
+        if self.use_checkpoint:
+            for blk in self.blocks:
+                x = checkpoint.checkpoint(blk, x)
+        else:
+            for blk in self.blocks:
+                x = blk(x)
+
+        if return_token_num > 0:
+            x = self.head(
+                self.norm(x[:, -return_token_num:])
+            )  # only return the mask tokens predict pixels
+        else:
+            x = self.head(self.norm(x))
+
+        return x
+
+
+class PretrainVisionTransformer(nn.Module):
+    """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        encoder_in_chans=3,
+        encoder_num_classes=0,
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        decoder_num_classes=1536,  # decoder_num_classes=768,
+        decoder_embed_dim=512,
+        decoder_depth=8,
+        decoder_num_heads=8,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.0,
+        norm_layer=nn.LayerNorm,
+        init_values=0.0,
+        use_learnable_pos_emb=False,
+        use_checkpoint=False,
+        tubelet_size=2,
+        num_classes=0,  # avoid the error from create_fn in timm
+        in_chans=0,  # avoid the error from create_fn in timm
+    ):
+        super().__init__()
+        self.encoder = PretrainVisionTransformerEncoder(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=encoder_in_chans,
+            num_classes=encoder_num_classes,
+            embed_dim=encoder_embed_dim,
+            depth=encoder_depth,
+            num_heads=encoder_num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=drop_path_rate,
+            norm_layer=norm_layer,
+            init_values=init_values,
+            tubelet_size=tubelet_size,
+            use_checkpoint=use_checkpoint,
+            use_learnable_pos_emb=use_learnable_pos_emb,
+        )
+
+        self.decoder = PretrainVisionTransformerDecoder(
+            patch_size=patch_size,
+            num_patches=self.encoder.patch_embed.num_patches,
+            num_classes=decoder_num_classes,
+            embed_dim=decoder_embed_dim,
+            depth=decoder_depth,
+            num_heads=decoder_num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=drop_path_rate,
+            norm_layer=norm_layer,
+            init_values=init_values,
+            tubelet_size=tubelet_size,
+            use_checkpoint=use_checkpoint,
+        )
+
+        self.encoder_to_decoder = nn.Linear(
+            encoder_embed_dim, decoder_embed_dim, bias=False
+        )
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+        self.pos_embed = get_sinusoid_encoding_table(
+            self.encoder.patch_embed.num_patches, decoder_embed_dim
+        )
+
+        trunc_normal_(self.mask_token, std=0.02)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {"pos_embed", "cls_token", "mask_token"}
+
+    def forward(self, x, mask):
+        _, _, T, _, _ = x.shape
+        x_vis = self.encoder(x, mask)  # [B, N_vis, C_e]
+        x_vis = self.encoder_to_decoder(x_vis)  # [B, N_vis, C_d]
+        B, N, C = x_vis.shape
+        # we don't unshuffle the correct visible token order,
+        # but shuffle the pos embedding accordingly.
+        expand_pos_embed = (
+            self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
+        )
+        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
+        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
+        x_full = torch.cat(
+            [x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1
+        )  # [B, N, C_d]
+        x = self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]
+
+        return x
+
+
+def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
+    model = PretrainVisionTransformer(
+        img_size=224,
+        patch_size=16,
+        encoder_embed_dim=384,
+        encoder_depth=12,
+        encoder_num_heads=6,
+        encoder_num_classes=0,
+        decoder_num_classes=1536,
+        decoder_embed_dim=192,
+        decoder_num_heads=3,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+
+def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
+    model = PretrainVisionTransformer(
+        img_size=224,
+        patch_size=16,
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_num_classes=0,
+        decoder_num_classes=1536,
+        decoder_embed_dim=384,
+        decoder_num_heads=6,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+
+def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
+    model = PretrainVisionTransformer(
+        img_size=224,
+        patch_size=16,
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_num_classes=0,
+        decoder_num_classes=1536,
+        decoder_embed_dim=512,
+        decoder_num_heads=8,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+
+def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
+    model = PretrainVisionTransformer(
+        img_size=224,
+        patch_size=16,
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_num_classes=0,
+        decoder_num_classes=1536,
+        decoder_embed_dim=640,
+        decoder_num_heads=8,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
+
+
+def load_model(
+    path: str,
+    mask_ratio: float,
+    device: "torch.device",
+    num_frames: int = 16,
+    input_size: int = 224,
+) -> Tuple[torch.nn.Module, torch.Tensor, Tuple[int, ...]]:
+    model = pretrain_videomae_base_patch16_224(
+        pretrained=False, drop_path_rate=0.0, decoder_depth=4
+    ).to(device)
+    patch_size = model.encoder.patch_embed.patch_size
+    window_size = (
+        num_frames // 2,
+        input_size // patch_size[0],
+        input_size // patch_size[1],
+    )
+
+    weights = torch.load(path, map_location="cpu")
+    model.load_state_dict(weights["model"])
+    model.eval()
+
+    masked_generator = TubeMaskingGenerator(window_size, mask_ratio)
+    masks = torch.from_numpy(masked_generator())
+
+    return model, masks, patch_size
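
`load_model` is the only symbol exported via `__all__`; a sketch of how it is consumed (mirroring `app.py`), assuming the LFS checkpoint has already been fetched:

```python
import torch

from models import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Builds the base patch16 pre-training model and loads assets/checkpoint.pth into it.
model, masks, patch_size = load_model(
    path="assets/checkpoint.pth", mask_ratio=0.75, device=device
)
print(patch_size)            # (16, 16) for the base patch16 model
print(masks.shape)           # torch.Size([1568]): one flat tube mask over 8 * 14 * 14 positions
print(type(model).__name__)  # PretrainVisionTransformer
```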
pyproject.toml ADDED
@@ -0,0 +1,2 @@
+[tool.isort]
+profile = "black"
requirements.txt ADDED
@@ -0,0 +1,6 @@
+einops
+decord
+numpy
+timm
+torch
+torchvision
utils.py ADDED
@@ -0,0 +1,144 @@
+from typing import List, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from einops import rearrange
+from PIL import Image
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision import transforms
+from torchvision.transforms import ToPILImage
+
+
+def get_frames(
+    path: str, transform: transforms.Compose, num_frames: int = 16
+) -> Tuple[torch.Tensor, List[int]]:
+    vr = VideoReader(path, ctx=cpu(0))
+    tmp = np.arange(0, num_frames * 2, 2) + 60
+    frame_id_list = tmp.tolist()
+    video_data = vr.get_batch(frame_id_list).asnumpy()
+    frames, _ = transform(
+        (
+            [
+                Image.fromarray(video_data[vid, :, :, :]).convert("RGB")
+                for vid, _ in enumerate(frame_id_list)
+            ],
+            None,
+        )
+    )
+    frames = frames.view((num_frames, 3) + frames.size()[-2:]).transpose(0, 1)
+
+    return frames, frame_id_list
+
+
+def prepare_frames_masks(
+    frames: torch.Tensor, masks: torch.Tensor, device: "torch.device"
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    frames = frames.unsqueeze(0)
+    masks = masks.unsqueeze(0)
+
+    frames = frames.to(device, non_blocking=True)
+    masks = masks.to(device, non_blocking=True).flatten(1).to(torch.bool)
+
+    return frames, masks
+
+
+def get_videomae_outputs(
+    frames: torch.Tensor,
+    masks: torch.Tensor,
+    outputs: torch.Tensor,
+    ids: List[int],
+    patch_size: Tuple[int, ...],
+    device: "torch.device",
+):
+    visualisations = []
+
+    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
+    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
+    ori_img = frames * std + mean  # in [0, 1]
+    original_images = [
+        ToPILImage()(ori_img[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
+    ]
+
+    img_squeeze = rearrange(
+        ori_img,
+        "b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c",
+        p0=2,
+        p1=patch_size[0],
+        p2=patch_size[0],
+    )
+    img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (
+        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
+    )
+    img_patch = rearrange(img_norm, "b n p c -> b n (p c)")
+    img_patch[masks] = outputs
+
+    # make mask
+    mask = torch.ones_like(img_patch)
+    mask[masks] = 0
+    mask = rearrange(mask, "b n (p c) -> b n p c", c=3)
+    mask = rearrange(
+        mask,
+        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2) ",
+        p0=2,
+        p1=patch_size[0],
+        p2=patch_size[1],
+        h=14,
+        w=14,
+    )
+
+    # save reconstruction video
+    rec_img = rearrange(img_patch, "b n (p c) -> b n p c", c=3)
+    rec_img = rec_img * (
+        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
+    ) + img_squeeze.mean(dim=-2, keepdim=True)
+    rec_img = rearrange(
+        rec_img,
+        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
+        p0=2,
+        p1=patch_size[0],
+        p2=patch_size[1],
+        h=14,
+        w=14,
+    )
+    reconstructed_images = [
+        ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0, 0.996))
+        for vid, _ in enumerate(ids)
+    ]
+
+    # save masked video
+    img_mask = rec_img * mask
+    masked_images = [
+        ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
+    ]
+
+    assert len(original_images) == len(reconstructed_images) == len(masked_images)
+
+    for i in range(len(original_images)):
+        visualisations.append(
+            [original_images[i], masked_images[i], reconstructed_images[i]]
+        )
+
+    return visualisations
+
+
+def create_plot(images):
+    num_cols = 3
+    num_rows = 16
+    column_names = ["Original Patch", "Masked Patch", "Reconstructed Patch"]
+
+    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 48))
+
+    for i in range(num_rows):
+        for j in range(num_cols):
+            axes[i, j].imshow(images[i][j])
+            axes[i, j].axis("off")
+
+            if i == 0:
+                axes[i, j].set_title(column_names[j], fontsize=16)
+
+    plt.tight_layout()
+    plt.show()
+
+    return fig
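
Finally, a small sketch of the frame-sampling helper in isolation (it assumes `decord` can open the bundled example clip; the printed shape follows the `(C, T, H, W)` layout the model expects):

```python
from augmentations import get_videomae_transform
from utils import get_frames

# Samples 16 frames (every other frame starting at index 60) and applies the
# crop / tensor / normalise pipeline from augmentations.py.
frames, frame_ids = get_frames(
    path="assets/example.mp4", transform=get_videomae_transform()
)
print(frames.shape)   # torch.Size([3, 16, 224, 224]) -> channels, frames, height, width
print(frame_ids[:4])  # [60, 62, 64, 66]
```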