Paolo-Fraccaro committed on
Commit c70cc56
1 Parent(s): 1ecae73

Upload Prithvi.py

Files changed (1)
  1. Prithvi.py +291 -0
Prithvi.py ADDED
@@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import Block
from timm.models.layers import to_2tuple, _assert

import numpy as np

from einops import rearrange
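# get_3d_sincos_pos_embed below calls this 1-D sin/cos helper, which is not
# defined elsewhere in this file; the standard implementation from the MAE
# reference code (facebookresearch/mae, util/pos_embed.py) is reproduced here
# so the module is self-contained.
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb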
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: 3d tuple of grid size: t, h, w
    return:
        pos_embed: L, D
    """

    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
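# Illustrative sketch (not part of the original upload): builds the 3-D sin/cos
# table for a small, arbitrary grid and checks its shape. embed_dim is split
# 6/16 for width, 6/16 for height and 4/16 for time, so it must be divisible by 16.
def _demo_3d_pos_embed():
    embed_dim, grid_size = 64, (2, 4, 4)  # (t, h, w); example values only
    pos = get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
    # one extra all-zero row is prepended for the cls token
    assert pos.shape == (2 * 4 * 4 + 1, embed_dim)
    return pos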
class PatchEmbed(nn.Module):
    """ Frames of 2D Images to Patch Embedding
    The 3D version of timm.models.vision_transformer.PatchEmbed
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            num_frames=3,
            tubelet_size=1,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.grid_size = (num_frames // tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
                              stride=(tubelet_size, patch_size[0], patch_size[1]), bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x
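# Illustrative sketch (not part of the original upload): tokenises a random
# (B, C, T, H, W) tensor with the 3-D PatchEmbed. The sizes below are arbitrary
# example values; in_chans=6 merely stands in for a multi-band input.
def _demo_patch_embed():
    pe = PatchEmbed(img_size=32, patch_size=16, num_frames=3, tubelet_size=1,
                    in_chans=6, embed_dim=768)
    x = torch.randn(2, 6, 3, 32, 32)  # B, C, T, H, W
    tokens = pe(x)
    # grid_size = (3, 2, 2) -> 12 tokens of dimension embed_dim
    assert tokens.shape == (2, pe.num_patches, 768)
    return tokens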
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16,
                 num_frames=3, tubelet_size=1,
                 in_chans=3, embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, tubelet_size * patch_size * patch_size * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()
    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def patchify(self, imgs):
        """
        imgs: B, C, T, H, W
        x: B, L, D
        """
        p = self.patch_embed.patch_size[0]
        tub = self.patch_embed.tubelet_size
        x = rearrange(imgs, 'b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)', tub=tub, p=p, q=p)

        return x

    def unpatchify(self, x):
        """
        x: B, L, D
        imgs: B, C, T, H, W
        """
        p = self.patch_embed.patch_size[0]
        num_p = self.patch_embed.img_size[0] // p
        tub = self.patch_embed.tubelet_size
        imgs = rearrange(x, 'b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)', h=num_p, w=num_p, tub=tub, p=p, q=p)
        return imgs
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
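    # Worked example (illustrative): with L = 588 tokens and mask_ratio = 0.75,
    # len_keep = int(588 * 0.25) = 147, so x_masked is (N, 147, D); mask contains
    # 147 zeros (kept) and 441 ones (removed) per sample, already unshuffled back
    # to the original token order via ids_restore.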
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x
    def forward_loss(self, imgs, pred, mask):
        """
        imgs: B, C, T, H, W
        target: B, L, D
        pred: B, L, D
        mask: B, L. 0 is keep, 1 is remove
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss
    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
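# Illustrative usage sketch (not part of the original upload). The tiny
# hyper-parameters below are arbitrary example values chosen so the check runs
# quickly on CPU; they are not the released Prithvi configuration.
if __name__ == "__main__":
    model = MaskedAutoencoderViT(
        img_size=64, patch_size=16,
        num_frames=3, tubelet_size=1,
        in_chans=6, embed_dim=128, depth=2, num_heads=4,
        decoder_embed_dim=64, decoder_depth=1, decoder_num_heads=4,
    )

    # (B, C, T, H, W) input; forward returns the masked-reconstruction loss,
    # the per-patch predictions and the binary mask (1 = masked / reconstructed).
    imgs = torch.randn(2, 6, 3, 64, 64)
    loss, pred, mask = model(imgs, mask_ratio=0.75)

    num_patches = model.patch_embed.num_patches  # 3 * 4 * 4 = 48
    patch_dim = 1 * 16 * 16 * 6                  # tubelet_size * p * p * in_chans
    assert pred.shape == (2, num_patches, patch_dim)
    assert mask.shape == (2, num_patches)

    # patchify / unpatchify are exact inverses of each other
    assert torch.equal(model.unpatchify(model.patchify(imgs)), imgs)

    print(f"loss: {loss.item():.4f}")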