kiankaydee committed
Commit 373b8b8 · 1 Parent(s): 8f0bd34
Files changed (6)
  1. config.yaml +15 -0
  2. loss.py +50 -0
  3. mae_modules.py +272 -0
  4. mae_utils.py +64 -0
  5. masking.py +46 -0
  6. vit.py +284 -0
config.yaml ADDED
@@ -0,0 +1,15 @@
loss:
  _target_: torch.nn.MSELoss  # combine with the Fourier loss weighted at a 0.01 mixing factor for best results
  reduction: none
optimizer:
  _target_: timm.optim.lion.Lion
  _partial_: true
  lr: &lr 1e-4  # 1e-4 for <= ViT-B, and 3e-5 for ViT-L
  weight_decay: 0.05
  betas: [0.9, 0.95]
lr_scheduler:
  _target_: torch.optim.lr_scheduler.OneCycleLR
  _partial_: true
  max_lr: *lr
  pct_start: 0.1
  anneal_strategy: cos
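For reference, a minimal sketch of how a config in this style is typically consumed; the `_target_`/`_partial_` keys suggest `hydra.utils.instantiate`, and `cfg`, `model`, and the step counts below are placeholders rather than anything defined in this commit.

from hydra.utils import instantiate


def build_training_objects(cfg, model, steps_per_epoch: int, epochs: int):
    # reduction="none" MSE, to be combined with the Fourier loss by the model
    loss_fn = instantiate(cfg.loss)
    # _partial_: true means instantiate() returns a partial awaiting the remaining arguments
    optimizer = instantiate(cfg.optimizer)(model.parameters())
    # OneCycleLR additionally needs the total step budget at call time
    lr_scheduler = instantiate(cfg.lr_scheduler)(optimizer, total_steps=steps_per_epoch * epochs)
    return loss_fn, optimizer, lr_scheduler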
loss.py ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn


class FourierLoss(nn.Module):
    def __init__(
        self,
        use_l1_loss: bool = True,
        num_multimodal_modalities: int = 1,  # set to 1 for vanilla MAE, 6 for channel-agnostic MAE
    ) -> None:
        """
        Fourier transform loss is only sound when using L1 or L2 loss to compare the frequency domains
        between the images / their radial histograms.

        We always set `reduction="none"`; any reduction of this loss's output must be handled by the
        model in question.
        """
        super().__init__()
        self.loss = nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
        self.num_modalities = num_multimodal_modalities

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # input = reconstructed image, target = original image
        # flattened images from MAE are (B, H*W, C), so here we convert to B x C x H x W (note: we assume H == W)
        flattened_images = len(input.shape) == len(target.shape) == 3
        if flattened_images:
            B, H_W, C = input.shape
            H_W = H_W // self.num_modalities
            four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5))
            input = input.view(*four_d_shape)
            target = target.view(*four_d_shape)
        else:
            B, C, h, w = input.shape
            H_W = h * w

        if len(input.shape) != 4 or len(target.shape) != 4:
            raise ValueError(f"Invalid input shape: got {input.shape} and {target.shape}.")

        fft_reconstructed = torch.fft.fft2(input)
        fft_original = torch.fft.fft2(target)

        magnitude_reconstructed = torch.abs(fft_reconstructed)
        magnitude_original = torch.abs(fft_original)

        loss_tensor: torch.Tensor = self.loss(magnitude_reconstructed, magnitude_original)

        if flattened_images:  # reshape the output loss back to the flattened token layout
            loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)

        return loss_tensor
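To tie this back to the config above, here is a hedged sketch of combining the reduction-free pixel MSE with FourierLoss at the suggested 0.01 mixing factor, averaging over masked tokens as in the standard MAE recipe. The shapes and variable names are illustrative only.

import torch

pixel_loss_fn = torch.nn.MSELoss(reduction="none")
fourier_loss_fn = FourierLoss(use_l1_loss=False, num_multimodal_modalities=1)

B, L, D = 2, 256, 16 * 16 * 6              # tokens of a 256x256, 6-channel image with patch size 16
pred = torch.randn(B, L, D)                # flattened reconstruction from the decoder
target = torch.randn(B, L, D)              # flattened original image
mask = (torch.rand(B, L) > 0.25).float()   # 1 = masked token, as produced by masking.py

per_token = pixel_loss_fn(pred, target).mean(dim=-1)                       # (B, L)
per_token = per_token + 0.01 * fourier_loss_fn(pred, target).mean(dim=-1)  # add weighted Fourier term
loss = (per_token * mask).sum() / mask.sum()                               # average over masked tokens only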
mae_modules.py ADDED
@@ -0,0 +1,272 @@
from functools import partial
from typing import Tuple, Union

import torch
import torch.nn as nn
from timm.models.helpers import checkpoint_seq
from timm.models.vision_transformer import Block, Mlp, VisionTransformer

from .masking import transformer_random_masking
from .vit import channel_agnostic_vit

# If you are interested in training new MAEs, combine an encoder and a decoder into a new module,
# leveraging the flattening and unflattening utilities from mae_utils.py as needed
# (a minimal assembly sketch follows this file's listing).
# Be sure to use an encoder-decoder Linear projection layer to match encoder dims with decoder dims.
# As described in the paper, images are self-standardized at the start.


class SelfStandardize(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.self_standardize = nn.LazyInstanceNorm2d(
            affine=False, track_running_stats=False
        )

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        x = pixels.float() / 255.0
        return self.self_standardize(x)


class MAEEncoder(nn.Module):
    def __init__(
        self,
        vit_backbone: VisionTransformer,
        max_in_chans: int = 6,
        channel_agnostic: bool = False,
    ) -> None:
        super().__init__()
        if channel_agnostic:
            self.vit_backbone = channel_agnostic_vit(
                vit_backbone, max_in_chans=max_in_chans
            )
        else:
            self.vit_backbone = vit_backbone
        self.max_in_chans = max_in_chans
        self.channel_agnostic = channel_agnostic

    @property
    def embed_dim(self) -> int:
        return int(self.vit_backbone.embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.vit_backbone.forward_features(x)
        x = self.vit_backbone.forward_head(x)
        return x  # type: ignore[no-any-return]

    def forward_masked(
        self,
        x: torch.Tensor,
        mask_ratio: float,
        constant_noise: Union[torch.Tensor, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.vit_backbone.patch_embed(x)
        x = self.vit_backbone._pos_embed(x)  # adds class token
        x_ = x[:, 1:, :]  # no class token
        x_, mask, ind_restore = transformer_random_masking(
            x_, mask_ratio, constant_noise
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
        x = self.vit_backbone.norm_pre(x)

        if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.vit_backbone.blocks, x)
        else:
            x = self.vit_backbone.blocks(x)
        x = self.vit_backbone.norm(x)
        return x, mask, ind_restore


class MAEDecoder(nn.Module):
    def __init__(
        self,
        embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # to be overwritten by MAE class
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.Sequential(
            *[
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pos_embeddings
        x = self.blocks(x)
        x = self.norm(x)
        return x  # type: ignore[no-any-return]

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # remove class token
        x_ = torch.gather(
            x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token

        x = x + self.pos_embeddings
        x = self.blocks(x)
        x = self.norm(x)
        return x  # type: ignore[no-any-return]


class CrossAttention(nn.Module):
    def __init__(
        self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, context):
        B, N, C = x.shape
        _, M, _ = context.shape

        q = (
            self.q(x)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        kv = (
            self.kv(context)
            .reshape(B, M, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CAMAEDecoder(nn.Module):
    def __init__(
        self,
        num_modalities: int = 6,
        tokens_per_modality: int = 256,
        embed_dim: int = 256,
        depth: int = 2,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
    ) -> None:
        super().__init__()
        self.num_modalities = num_modalities
        self.tokens_per_modality = tokens_per_modality
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # to be overwritten by MAE class
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.placeholder = nn.Parameter(
            torch.zeros(1, 1, embed_dim), requires_grad=False
        )
        self.modality_tokens = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                for modality in range(self.num_modalities)
            ]
        )

        self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
        self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))

        self.decoders = nn.ModuleList(
            [
                nn.Sequential(
                    *[
                        Block(
                            embed_dim,
                            num_heads,
                            mlp_ratio,
                            qkv_bias=qkv_bias,
                            norm_layer=norm_layer,
                        )
                        for i in range(depth)
                    ]
                )
                for modality in range(self.num_modalities)
            ]
        )
        # self.norm = norm_layer(embed_dim)  # we decided to drop the last layer norm
        self.context_norm = norm_layer(embed_dim)
        self.query_norm = norm_layer(embed_dim)
        self.out_norm = norm_layer(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_m_s = []

        modality_tokens_concat = torch.cat(
            [
                self.placeholder,
            ]  # placeholder for class token
            + [
                m_t.repeat(1, self.tokens_per_modality, 1)
                for m_t in self.modality_tokens
            ],
            dim=1,
        )

        x = (
            x + self.pos_embeddings + modality_tokens_concat
        )  # add pos and tiled modality tokens
        x_ = x[:, 1:, :]  # no class token
        for m, decoder in enumerate(
            self.decoders
        ):  # iterate through modalities and decoders
            x_m = x_[
                :, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, :
            ]
            x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_))
            x_m = x_m + self.mlp(self.out_norm(x_m))
            x_m = decoder(x_m)
            x_m_s.append(x_m)
        x_m_s = torch.cat(x_m_s, dim=1)  # concat all tokens
        # x_m_s = self.norm(x_m_s)  # we decided to drop the last layer norm
        x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1)  # add back class token

        return x_m_s

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # remove class token
        x_ = torch.gather(
            x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
        x = self.forward(x)
        return x
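Following the note at the top of mae_modules.py, here is a minimal, hedged assembly sketch (not the training module shipped in this commit) of how an encoder and decoder might be tied together: a Linear projection matches encoder and decoder widths, inputs are self-standardized, and flatten_images from mae_utils.py supplies the reconstruction target. The vanilla, non-channel-agnostic case with 6-channel inputs is assumed; SketchMAE, decoder_pred, and the hyperparameter defaults are illustrative names, not part of the repo.

import torch
import torch.nn as nn

from .mae_utils import flatten_images
from .vit import generate_2d_sincos_pos_embeddings


class SketchMAE(nn.Module):
    def __init__(
        self,
        encoder: MAEEncoder,
        decoder: MAEDecoder,
        patch_size: int = 16,
        in_chans: int = 6,
        mask_ratio: float = 0.75,
    ) -> None:
        super().__init__()
        self.self_standardize = SelfStandardize()
        self.encoder = encoder
        self.decoder = decoder
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        # Linear projection to match encoder dims with decoder dims
        self.encoder_to_decoder = nn.Linear(encoder.embed_dim, decoder.embed_dim)
        # predict the pixels of each patch from the decoded tokens
        self.decoder_pred = nn.Linear(decoder.embed_dim, patch_size**2 * in_chans)
        # the decoder expects its positional embeddings to be set by the enclosing MAE
        grid = encoder.vit_backbone.patch_embed.grid_size[0]
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(
            decoder.embed_dim, length=grid, use_class_token=True
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        img = self.self_standardize(img)
        latent, mask, ind_restore = self.encoder.forward_masked(img, self.mask_ratio)
        decoded = self.decoder.forward_masked(self.encoder_to_decoder(latent), ind_restore)
        pred = self.decoder_pred(decoded[:, 1:, :])    # drop the class token
        target = flatten_images(img, self.patch_size)  # (B, L, patch_size**2 * C)
        loss = ((pred - target) ** 2).mean(dim=-1)     # per-token reconstruction error
        return (loss * mask).sum() / mask.sum()        # average over masked tokens only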
mae_utils.py ADDED
@@ -0,0 +1,64 @@
import math

import torch


def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool = False) -> torch.Tensor:
    """
    Flattens 2D images into tokens with the same pixel values

    Parameters
    ----------
    img : input image tensor (N, C, H, W)

    Returns
    -------
    flattened_img: flattened image tensor (N, L, patch_size**2 * C)
    """

    if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0):
        raise ValueError("image H must equal image W and be divisible by patch_size")
    in_chans = img.shape[1]

    h = w = int(img.shape[2] // patch_size)
    x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size))

    if channel_agnostic:
        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHPWQ -> NCHWPQ
        x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2)))
    else:
        x = torch.permute(x, (0, 2, 4, 3, 5, 1))  # NCHPWQ -> NHWPQC
        x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans)))
    return x


def unflatten_tokens(
    tokens: torch.Tensor, patch_size: int, num_modalities: int = 1, channel_agnostic: bool = False
) -> torch.Tensor:
    """
    Unflattens tokens (N, L, patch_size**2 * C) into image tensor (N, C, H, W) with the pixel values

    Parameters
    ----------
    tokens : input token tensor (N, L, patch_size**2 * C)

    Returns
    -------
    img: image tensor (N, C, H, W)
    """
    if num_modalities > 1 and not channel_agnostic:
        raise ValueError("Multiple modalities requires channel agnostic unflattening.")

    h = w = int(math.sqrt(tokens.shape[1] // num_modalities))
    if h * w != (tokens.shape[1] // num_modalities):
        raise ValueError("sqrt of number of tokens not integer")

    if channel_agnostic:
        x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size))
        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHWPQ -> NCHPWQ
    else:
        x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1))
        x = torch.permute(x, (0, 5, 1, 3, 2, 4))  # NHWPQC -> NCHPWQ
    img = x.reshape(shape=(x.shape[0], -1, h * patch_size, h * patch_size))

    return img
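A quick, illustrative sanity check (not part of the commit) of the flatten/unflatten round trip in the channel-agnostic layout; the sizes below are assumptions matching the 256x256, 6-channel, patch-16 setting used elsewhere in this repo.

import torch

img = torch.randn(2, 6, 256, 256)                                    # (N, C, H, W)
tokens = flatten_images(img, patch_size=16, channel_agnostic=True)   # one token per patch per channel
assert tokens.shape == (2, 6 * 16 * 16, 16 * 16)
recon = unflatten_tokens(tokens, patch_size=16, num_modalities=6, channel_agnostic=True)
assert torch.equal(recon, img)                                       # the round trip is exact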
masking.py ADDED
@@ -0,0 +1,46 @@
from typing import Tuple, Union

import torch


def transformer_random_masking(
    x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Randomly mask patches per sample

    Parameters
    ----------
    x : token tensor (N, L, D)
    mask_ratio : float - ratio of image to mask
    constant_noise : None, or a tensor of shape (N, L) used to produce consistent masks

    Returns
    -------
    x_masked : sub-sampled version of x (N, int(L * (1 - mask_ratio)), D)
    mask : binary mask indicating masked tokens (1 where masked) (N, L)
    ind_restore : locations of masked tokens, needed for decoder
    """

    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    # use random noise to generate batch-based random masks
    if constant_noise is not None:
        noise = constant_noise
    else:
        noise = torch.rand(N, L, device=x.device)

    shuffled_tokens = torch.argsort(noise, dim=1)  # shuffled index
    ind_restore = torch.argsort(shuffled_tokens, dim=1)  # unshuffled index

    # get masked input
    tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
    x_masked = torch.gather(x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))

    # get binary mask used for loss masking: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ind_restore)  # unshuffle to get the binary mask

    return x_masked, mask, ind_restore
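An illustrative look (not from the repo) at what transformer_random_masking returns and how ind_restore undoes the shuffle, mirroring its use in MAEDecoder.forward_masked; the tensor sizes below are assumptions for the example.

import torch

tokens = torch.randn(2, 256, 384)                          # (N, L, D) patch tokens, class token excluded
x_masked, mask, ind_restore = transformer_random_masking(tokens, mask_ratio=0.75)
print(x_masked.shape)                                      # torch.Size([2, 64, 384]): 25% of tokens kept
print(mask.shape, mask.sum(dim=1))                         # (2, 256) with 192 masked tokens per sample

# pad the kept tokens and gather with ind_restore: every visible token returns to its original position
padded = torch.cat([x_masked, torch.zeros(2, 192, 384)], dim=1)
restored = torch.gather(padded, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, 384))
assert torch.all(restored[mask == 0] == tokens[mask == 0])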
vit.py ADDED
@@ -0,0 +1,284 @@
import timm.models.vision_transformer as vit
import torch


def generate_2d_sincos_pos_embeddings(
    embedding_dim: int, length: int, scale: float = 10000.0, use_class_token: bool = True, num_modality: int = 1
) -> torch.nn.Parameter:
    """
    Generate 2-dimensional sin/cos positional embeddings

    Parameters
    ----------
    embedding_dim : int
        embedding dimension used in vit
    length : int
        number of tokens along height or width of image after patching (assuming square)
    scale : float
        scale for sin/cos functions
    use_class_token : bool
        True - a zero vector is prepended for the class_token, False - no vector added
    num_modality : int
        number of modalities. If 1, a single modality is assumed.
        Otherwise the sin/cos encoding is tiled once per modality.

    Returns
    -------
    positional_encoding : torch.Tensor
        positional encoding to add to vit patch encodings
        [num_modality*length*length, embedding_dim] or [1+num_modality*length*length, embedding_dim]
        (w/ or w/o cls_token)
    """

    linear_positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(linear_positions, linear_positions, indexing="ij")
    positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
    positional_weights = torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    positional_weights = 1.0 / (scale**positional_weights)

    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
    width_weights = torch.outer(width_mesh.flatten(), positional_weights)

    positional_encoding = torch.cat(
        [torch.sin(height_weights), torch.cos(height_weights), torch.sin(width_weights), torch.cos(width_weights)],
        dim=1,
    )[None, :, :]

    # repeat positional encoding for multiple channel modalities
    positional_encoding = positional_encoding.repeat(1, num_modality, 1)

    if use_class_token:
        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)

    positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)

    return positional_encoding


class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        embed_dim: int,
        bias: bool = True,
    ) -> None:
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,  # in_chans is used by self.proj, which we override anyway
            embed_dim=embed_dim,
            norm_layer=None,
            flatten=False,
            bias=bias,
        )
        # channel-agnostic MAE has a single projection for all chans
        self.proj = torch.nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_chans = x.shape[1]
        x = torch.stack([self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2)  # single projection for all chans
        x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
        return x


class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        # rewrite of https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))

        # TODO: upgrade timm to get access to register tokens
        # if self.vit_backbone.reg_token is not None:
        #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        # MAIN DIFFERENCE from timm: we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs.
        # This supports having CA-MAEs actually be channel-agnostic at inference time.
        if self.no_embed_class:
            x = x + self.pos_embed[:, : x.shape[1]]
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + self.pos_embed[:, : x.shape[1]]
        return self.pos_drop(x)  # type: ignore[no-any-return]


def channel_agnostic_vit(vit_backbone: vit.VisionTransformer, max_in_chans: int) -> vit.VisionTransformer:
    # replace patch embedding with channel-agnostic version
    vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
        img_size=vit_backbone.patch_embed.img_size[0],
        patch_size=vit_backbone.patch_embed.patch_size[0],
        embed_dim=vit_backbone.embed_dim,
    )

    # replace positional embedding with channel-agnostic version
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        embedding_dim=vit_backbone.embed_dim,
        length=vit_backbone.patch_embed.grid_size[0],
        use_class_token=vit_backbone.cls_token is not None,
        num_modality=max_in_chans,
    )

    # change the class to ChannelAgnosticViT so that it actually uses the new _pos_embed
    vit_backbone.__class__ = ChannelAgnosticViT
    return vit_backbone


def sincos_positional_encoding_vit(
    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
) -> vit.VisionTransformer:
    """Attaches no-grad sin/cos positional embeddings to a pre-constructed ViT backbone model.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the constructed vision transformer from timm
    scale : float (default 10000.0)
        hyperparameter for sin/cos positional embeddings; we recommend keeping it at 10,000

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same ViT but with fixed no-grad positional encodings added to the vit patch encodings
    """
    # length: number of tokens along height or width of image after patching (assuming square)
    length = vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
    pos_embeddings = generate_2d_sincos_pos_embeddings(
        vit_backbone.embed_dim, length=length, scale=scale, use_class_token=vit_backbone.cls_token is not None
    )
    # note: if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = pos_embeddings
    return vit_backbone


def vit_small_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch16_224(**default_kwargs)


def vit_small_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch32_384(**default_kwargs)


def vit_base_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch8_224(**default_kwargs)


def vit_base_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch16_224(**default_kwargs)


def vit_base_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch32_384(**default_kwargs)


def vit_large_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        patch_size=8,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.VisionTransformer(**default_kwargs)


def vit_large_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_large_patch16_384(**default_kwargs)
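Finally, a hedged construction sketch (illustrative, not shipped in this commit) showing how these pieces might be combined into a channel-agnostic ViT-S/16 backbone for 256x256, 6-channel images, the same path MAEEncoder takes when channel_agnostic=True; the tensor sizes in the comments assume that configuration.

import torch

backbone = vit_small_patch16_256()                        # ViT-S/16 with the defaults above (embed_dim=384)
# for a fixed-channel model one would instead call: sincos_positional_encoding_vit(backbone)
backbone = channel_agnostic_vit(backbone, max_in_chans=6)

imgs = torch.randn(2, 6, 256, 256)
tokens = backbone.patch_embed(imgs)                       # (2, 6*16*16, 384): one token per patch per channel
tokens = backbone._pos_embed(tokens)                      # prepends the class token and adds tiled positions
print(tokens.shape)                                       # torch.Size([2, 1537, 384])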