Capx

Alyosha11 committed on
Commit 5e83696 · verified · 1 Parent(s): 5dc18c5

Upload 8 files
Files changed (8)
  1. VIDSatCLIP.ipynb +0 -0
  2. __init__.py +5 -0
  3. load.py +18 -0
  4. load_lightweight.py +36 -0
  5. location_encoder.py +275 -0
  6. loss.py +47 -0
  7. main.py +159 -0
  8. model.py +400 -0
VIDSatCLIP.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
__init__.py ADDED
@@ -0,0 +1,5 @@
+ from . import *
+ from .main import *
+ from .model import *
+ from .loss import *
+ from .location_encoder import *
load.py ADDED
@@ -0,0 +1,18 @@
+ from main import *
+
+ def get_satclip(ckpt_path, device, return_all=False):
+     ckpt = torch.load(ckpt_path, map_location=device)
+     ckpt['hyper_parameters'].pop('eval_downstream')
+     ckpt['hyper_parameters'].pop('air_temp_data_path')
+     ckpt['hyper_parameters'].pop('election_data_path')
+     lightning_model = SatCLIPLightningModule(**ckpt['hyper_parameters']).to(device)
+
+     lightning_model.load_state_dict(ckpt['state_dict'])
+     lightning_model.eval()
+
+     geo_model = lightning_model.model
+
+     if return_all:
+         return geo_model
+     else:
+         return geo_model.location
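
A minimal usage sketch for get_satclip; the checkpoint filename is a placeholder and the (lon, lat) coordinate order is an assumption, and load.py must be importable from the repository root since it does a wildcard import from main:

import torch
from load import get_satclip  # load.py itself does `from main import *`

device = "cuda" if torch.cuda.is_available() else "cpu"

# "satclip.ckpt" is a placeholder for a downloaded SatCLIP checkpoint.
loc_encoder = get_satclip("satclip.ckpt", device=device)

coords = torch.tensor([[-74.0060, 40.7128], [13.4050, 52.5200]],
                      dtype=torch.float64, device=device)  # assumed (lon, lat) pairs
with torch.no_grad():
    embeddings = loc_encoder(coords)  # shape (2, embed_dim)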
load_lightweight.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ from location_encoder import get_neural_network, get_positional_encoding, LocationEncoder
+
+
+ def get_satclip_loc_encoder(ckpt_path, device):
+     ckpt = torch.load(ckpt_path, map_location=device)
+     hp = ckpt['hyper_parameters']
+
+     posenc = get_positional_encoding(
+         hp['le_type'],
+         hp['legendre_polys'],
+         hp['harmonics_calculation'],
+         hp['min_radius'],
+         hp['max_radius'],
+         hp['frequency_num']
+     )
+
+     nnet = get_neural_network(
+         hp['pe_type'],
+         posenc.embedding_dim,
+         hp['embed_dim'],
+         hp['capacity'],
+         hp['num_hidden_layers']
+     )
+
+     # only load nnet params from the state dict
+     state_dict = ckpt['state_dict']
+     state_dict = {k[k.index('nnet'):]: state_dict[k]
+                   for k in state_dict.keys() if 'nnet' in k}
+
+     loc_encoder = LocationEncoder(posenc, nnet).double()
+     loc_encoder.load_state_dict(state_dict)
+     loc_encoder.eval()
+
+     return loc_encoder
+
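
A corresponding sketch for the lightweight loader, which rebuilds only the location encoder from the checkpoint's hyperparameters and so avoids the vision backbone and its timm/torchgeo dependencies; the checkpoint path is again a placeholder and the coordinate order an assumption:

import torch
from load_lightweight import get_satclip_loc_encoder

loc_encoder = get_satclip_loc_encoder("satclip.ckpt", device="cpu")  # placeholder path

coords = torch.tensor([[-0.1276, 51.5072]], dtype=torch.float64)  # assumed (lon, lat)
with torch.no_grad():
    emb = loc_encoder(coords)
print(emb.shape)  # (1, embed_dim), with embed_dim read from the checkpoint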
location_encoder.py ADDED
@@ -0,0 +1,275 @@
+ from torch import nn, optim
+ import math
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ import numpy as np
+ from datetime import datetime
+ import positional_encoding as PE
+
+ """
+ FCNet
+ """
+ class ResLayer(nn.Module):
+     def __init__(self, linear_size):
+         super(ResLayer, self).__init__()
+         self.l_size = linear_size
+         self.nonlin1 = nn.ReLU(inplace=True)
+         self.nonlin2 = nn.ReLU(inplace=True)
+         self.dropout1 = nn.Dropout()
+         self.w1 = nn.Linear(self.l_size, self.l_size)
+         self.w2 = nn.Linear(self.l_size, self.l_size)
+
+     def forward(self, x):
+         y = self.w1(x)
+         y = self.nonlin1(y)
+         y = self.dropout1(y)
+         y = self.w2(y)
+         y = self.nonlin2(y)
+         out = x + y
+
+         return out
+
+ class FCNet(nn.Module):
+     def __init__(self, num_inputs, num_classes, dim_hidden):
+         super(FCNet, self).__init__()
+         self.inc_bias = False
+         self.class_emb = nn.Linear(dim_hidden, num_classes, bias=self.inc_bias)
+
+         self.feats = nn.Sequential(nn.Linear(num_inputs, dim_hidden),
+                                    nn.ReLU(inplace=True),
+                                    ResLayer(dim_hidden),
+                                    ResLayer(dim_hidden),
+                                    ResLayer(dim_hidden),
+                                    ResLayer(dim_hidden))
+
+     def forward(self, x):
+         loc_emb = self.feats(x)
+         class_pred = self.class_emb(loc_emb)
+         return class_pred
+
+ """A simple Multi-Layer Perceptron"""
+ class MLP(nn.Module):
+     def __init__(self, input_dim, dim_hidden, num_layers, out_dims):
+         super(MLP, self).__init__()
+
+         layers = []
+         layers += [nn.Linear(input_dim, dim_hidden, bias=True), nn.ReLU()]  # input layer
+         layers += [nn.Linear(dim_hidden, dim_hidden, bias=True), nn.ReLU()] * num_layers  # hidden layers
+         layers += [nn.Linear(dim_hidden, out_dims, bias=True)]  # output layer
+
+         self.features = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.features(x)
+
+ def exists(val):
+     return val is not None
+
+ def cast_tuple(val, repeat=1):
+     return val if isinstance(val, tuple) else ((val,) * repeat)
+
+ """Sinusoidal Representation Network (SIREN)"""
+ class SirenNet(nn.Module):
+     def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0=1., w0_initial=30., use_bias=True, final_activation=None, degreeinput=False, dropout=True):
+         super().__init__()
+         self.num_layers = num_layers
+         self.dim_hidden = dim_hidden
+         self.degreeinput = degreeinput
+
+         self.layers = nn.ModuleList([])
+         for ind in range(num_layers):
+             is_first = ind == 0
+             layer_w0 = w0_initial if is_first else w0
+             layer_dim_in = dim_in if is_first else dim_hidden
+
+             self.layers.append(Siren(
+                 dim_in=layer_dim_in,
+                 dim_out=dim_hidden,
+                 w0=layer_w0,
+                 use_bias=use_bias,
+                 is_first=is_first,
+                 dropout=dropout
+             ))
+
+         final_activation = nn.Identity() if not exists(final_activation) else final_activation
+         self.last_layer = Siren(dim_in=dim_hidden, dim_out=dim_out, w0=w0, use_bias=use_bias, activation=final_activation, dropout=False)
+
+     def forward(self, x, mods=None):
+
+         # normalize degree inputs into a -pi to pi range
+         if self.degreeinput:
+             x = torch.deg2rad(x) - torch.pi
+
+         mods = cast_tuple(mods, self.num_layers)
+
+         for layer, mod in zip(self.layers, mods):
+             x = layer(x)
+
+             if exists(mod):
+                 x *= rearrange(mod, 'd -> () d')
+
+         return self.last_layer(x)
+
+ class Sine(nn.Module):
+     def __init__(self, w0=1.):
+         super().__init__()
+         self.w0 = w0
+     def forward(self, x):
+         return torch.sin(self.w0 * x)
+
+ class Siren(nn.Module):
+     def __init__(self, dim_in, dim_out, w0=1., c=6., is_first=False, use_bias=True, activation=None, dropout=False):
+         super().__init__()
+         self.dim_in = dim_in
+         self.is_first = is_first
+         self.dim_out = dim_out
+         self.dropout = dropout
+
+         weight = torch.zeros(dim_out, dim_in)
+         bias = torch.zeros(dim_out) if use_bias else None
+         self.init_(weight, bias, c=c, w0=w0)
+
+         self.weight = nn.Parameter(weight)
+         self.bias = nn.Parameter(bias) if use_bias else None
+         self.activation = Sine(w0) if activation is None else activation
+
+     def init_(self, weight, bias, c, w0):
+         dim = self.dim_in
+
+         w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
+         weight.uniform_(-w_std, w_std)
+
+         if exists(bias):
+             bias.uniform_(-w_std, w_std)
+
+     def forward(self, x):
+         out = F.linear(x, self.weight, self.bias)
+         if self.dropout:
+             out = F.dropout(out, training=self.training)
+         out = self.activation(out)
+         return out
+
+
+ class Modulator(nn.Module):
+     def __init__(self, dim_in, dim_hidden, num_layers):
+         super().__init__()
+         self.layers = nn.ModuleList([])
+
+         for ind in range(num_layers):
+             is_first = ind == 0
+             dim = dim_in if is_first else (dim_hidden + dim_in)
+
+             self.layers.append(nn.Sequential(
+                 nn.Linear(dim, dim_hidden),
+                 nn.ReLU()
+             ))
+
+     def forward(self, z):
+         x = z
+         hiddens = []
+
+         for layer in self.layers:
+             x = layer(x)
+             hiddens.append(x)
+             x = torch.cat((x, z))
+
+         return tuple(hiddens)
+
+ class SirenWrapper(nn.Module):
+     def __init__(self, net, image_width, image_height, latent_dim=None):
+         super().__init__()
+         assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'
+
+         self.net = net
+         self.image_width = image_width
+         self.image_height = image_height
+
+         self.modulator = None
+         if exists(latent_dim):
+             self.modulator = Modulator(
+                 dim_in=latent_dim,
+                 dim_hidden=net.dim_hidden,
+                 num_layers=net.num_layers
+             )
+
+         tensors = [torch.linspace(-1, 1, steps=image_height), torch.linspace(-1, 1, steps=image_width)]
+         mgrid = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1)
+         mgrid = rearrange(mgrid, 'h w c -> (h w) c')
+         self.register_buffer('grid', mgrid)
+
+     def forward(self, img=None, *, latent=None):
+         modulate = exists(self.modulator)
+         assert not (modulate ^ exists(latent)), 'a latent vector must only be supplied if `latent_dim` was passed in on instantiation'
+
+         mods = self.modulator(latent) if modulate else None
+
+         coords = self.grid.clone().detach().requires_grad_()
+         out = self.net(coords, mods)
+         out = rearrange(out, '(h w) c -> () c h w', h=self.image_height, w=self.image_width)
+
+         if exists(img):
+             return F.mse_loss(img, out)
+
+         return out
+
+ def get_positional_encoding(name, legendre_polys=10, harmonics_calculation='analytic', min_radius=1, max_radius=360, frequency_num=10):
+     if name == "direct":
+         return PE.Direct()
+     elif name == "cartesian3d":
+         return PE.Cartesian3D()
+     elif name == "sphericalharmonics":
+         if harmonics_calculation == 'discretized':
+             return PE.DiscretizedSphericalHarmonics(legendre_polys=legendre_polys)
+         else:
+             return PE.SphericalHarmonics(legendre_polys=legendre_polys,
+                                          harmonics_calculation=harmonics_calculation)
+     elif name == "theory":
+         return PE.Theory(min_radius=min_radius,
+                          max_radius=max_radius,
+                          frequency_num=frequency_num)
+     elif name == "wrap":
+         return PE.Wrap()
+     elif name in ["grid", "spherec", "spherecplus", "spherem", "spheremplus"]:
+         return PE.GridAndSphere(min_radius=min_radius,
+                                 max_radius=max_radius,
+                                 frequency_num=frequency_num,
+                                 name=name)
+     else:
+         raise ValueError(f"{name} is not a known positional encoding.")
+
+ def get_neural_network(name, input_dim, num_classes=256, dim_hidden=256, num_layers=2):
+     if name == "linear":
+         return nn.Linear(input_dim, num_classes)
+     elif name == "mlp":
+         return MLP(
+             input_dim=input_dim,
+             dim_hidden=dim_hidden,
+             num_layers=num_layers,
+             out_dims=num_classes
+         )
+     elif name == "siren":
+         return SirenNet(
+             dim_in=input_dim,
+             dim_hidden=dim_hidden,
+             num_layers=num_layers,
+             dim_out=num_classes
+         )
+     elif name == "fcnet":
+         return FCNet(
+             num_inputs=input_dim,
+             num_classes=num_classes,
+             dim_hidden=dim_hidden
+         )
+     else:
+         raise ValueError(f"{name} is not a known neural network.")
+
+ class LocationEncoder(nn.Module):
+     def __init__(self, posenc, nnet):
+         super().__init__()
+         self.posenc = posenc
+         self.nnet = nnet
+
+     def forward(self, x):
+         x = self.posenc(x)
+         return self.nnet(x)
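
The two factory functions compose into a LocationEncoder as sketched below with untrained weights; the "sphericalharmonics" + "siren" combination and the 256-dimensional output are illustrative choices, and the positional_encoding module imported as PE above is assumed to be available even though it is not among the uploaded files:

import torch
from location_encoder import get_positional_encoding, get_neural_network, LocationEncoder

posenc = get_positional_encoding("sphericalharmonics", legendre_polys=10)  # illustrative choice
nnet = get_neural_network("siren", input_dim=posenc.embedding_dim,
                          num_classes=256, dim_hidden=256, num_layers=2)
encoder = LocationEncoder(posenc, nnet).double().eval()  # eval() disables SIREN dropout

coords = torch.tensor([[13.4050, 52.5200]], dtype=torch.float64)  # assumed (lon, lat)
with torch.no_grad():
    print(encoder(coords).shape)  # torch.Size([1, 256])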
loss.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+
+ class SatCLIPLoss(nn.Module):
+
+     def __init__(
+         self,
+         local_loss=False,
+         cache_labels=False,
+         rank=0,
+         world_size=1,
+     ):
+         super().__init__()
+         self.local_loss = local_loss
+         self.cache_labels = cache_labels
+         self.rank = rank
+         self.world_size = world_size
+
+         # cache state
+         self.prev_num_logits = 0
+         self.labels = {}
+
+     def get_ground_truth(self, device, num_logits) -> torch.Tensor:
+         # calculate the ground truth and cache it if enabled
+         if self.prev_num_logits != num_logits or device not in self.labels:
+             labels = torch.arange(num_logits, device=device, dtype=torch.long)
+             if self.world_size > 1 and self.local_loss:
+                 labels = labels + num_logits * self.rank
+             if self.cache_labels:
+                 self.labels[device] = labels
+                 self.prev_num_logits = num_logits
+         else:
+             labels = self.labels[device]
+         return labels
+
+     def forward(self, logits_per_image, logits_per_coord, output_dict=False):
+         device = logits_per_image.device
+
+         labels = self.get_ground_truth(device, logits_per_image.shape[0])
+
+         total_loss = (
+             F.cross_entropy(logits_per_image, labels) +
+             F.cross_entropy(logits_per_coord, labels)
+         ) / 2
+
+         return {"contrastive_loss": total_loss} if output_dict else total_loss
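
A quick sanity check of SatCLIPLoss on random logits; with square (batch, batch) logit matrices, the implicit targets are the diagonal image/coordinate pairs, and the result is the symmetric cross-entropy averaged over both directions:

import torch
from loss import SatCLIPLoss

loss_fn = SatCLIPLoss()
batch_size = 8
logits_per_image = torch.randn(batch_size, batch_size)
logits_per_coord = logits_per_image.t()  # SatCLIP.forward returns the transpose as its second output

loss = loss_fn(logits_per_image, logits_per_coord)
print(float(loss))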
main.py ADDED
@@ -0,0 +1,159 @@
+ import argparse
+ import os
+ from datetime import datetime
+
+ import lightning.pytorch
+ import torch
+ from datamodules.s2geo_dataset import S2GeoDataModule
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from lightning.pytorch.cli import LightningCLI
+ from loss import SatCLIPLoss
+ from model import SatCLIP
+
+ torch.set_float32_matmul_precision('high')
+
+ class SatCLIPLightningModule(lightning.pytorch.LightningModule):
+     def __init__(
+         self,
+         embed_dim=512,
+         image_resolution=256,
+         vision_layers=12,
+         vision_width=768,
+         vision_patch_size=32,
+         in_channels=4,
+         le_type="grid",
+         pe_type="siren",
+         frequency_num=16,
+         max_radius=260,
+         min_radius=1,
+         legendre_polys=16,
+         harmonics_calculation="analytic",
+         sh_embedding_dims=32,
+         learning_rate=1e-4,
+         weight_decay=0.01,
+         num_hidden_layers=2,
+         capacity=256,
+     ) -> None:
+         super().__init__()
+
+         self.model = SatCLIP(
+             embed_dim=embed_dim,
+             image_resolution=image_resolution,
+             vision_layers=vision_layers,
+             vision_width=vision_width,
+             vision_patch_size=vision_patch_size,
+             in_channels=in_channels,
+             le_type=le_type,
+             pe_type=pe_type,
+             frequency_num=frequency_num,
+             max_radius=max_radius,
+             min_radius=min_radius,
+             legendre_polys=legendre_polys,
+             harmonics_calculation=harmonics_calculation,
+             sh_embedding_dims=sh_embedding_dims,
+             num_hidden_layers=num_hidden_layers,
+             capacity=capacity,
+         )
+
+         self.loss_fun = SatCLIPLoss()
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+         self.save_hyperparameters()
+
+     def common_step(self, batch, batch_idx):
+         images = batch["image"]
+         t_points = batch["point"].float()
+         logits_per_image, logits_per_coord = self.model(images, t_points)
+         return self.loss_fun(logits_per_image, logits_per_coord)
+
+     def training_step(self, batch, batch_idx):
+         loss = self.common_step(batch, batch_idx)
+         self.log("train_loss", loss)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss = self.common_step(batch, batch_idx)
+         self.log("val_loss", loss)
+         return loss
+
+     def configure_optimizers(self):
+         exclude = (
+             lambda n, p: p.ndim < 2
+             or "bn" in n
+             or "ln" in n
+             or "bias" in n
+             or "logit_scale" in n
+         )
+         include = lambda n, p: not exclude(n, p)
+
+         named_parameters = list(self.model.named_parameters())
+         gain_or_bias_params = [
+             p for n, p in named_parameters if exclude(n, p) and p.requires_grad
+         ]
+         rest_params = [
+             p for n, p in named_parameters if include(n, p) and p.requires_grad
+         ]
+
+         optimizer = torch.optim.AdamW(
+             [
+                 {"params": gain_or_bias_params, "weight_decay": 0.0},
+                 {
+                     "params": rest_params,
+                     "weight_decay": self.weight_decay,
+                 },  # specify in configs/default.yaml
+             ],
+             lr=self.learning_rate,  # specify in configs/default.yaml
+         )
+
+         return optimizer
+
+
+ class MyLightningCLI(LightningCLI):
+     def add_arguments_to_parser(self, parser):
+         parser.add_argument("--watchmodel", action="store_true")
+
+
+ def cli_main(default_config_filename="/configs/default.yaml"):
+
+
+     save_config_fn = default_config_filename.replace(".yaml", "-latest.yaml")
+     # modify configs/default.yaml for learning rate etc.
+     cli = MyLightningCLI(
+         model_class=SatCLIPLightningModule,
+         datamodule_class=S2GeoDataModule,
+         save_config_kwargs=dict(
+             config_filename=save_config_fn,
+             overwrite=True,
+         ),
+         trainer_defaults={
+             "accumulate_grad_batches": 16,
+             "log_every_n_steps": 10,
+         },
+         parser_kwargs={"default_config_files": [default_config_filename]},
+         seed_everything_default=0,
+         run=False,
+     )
+
+     ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
+     run_name = f"SatCLIP_S2_{ts}"
+     if cli.trainer.logger is not None:
+         cli.trainer.logger.experiment.name = run_name
+         # this seems to be necessary to force logging of the datamodule hyperparams
+         cli.trainer.logger.log_hyperparams(cli.datamodule.hparams)
+
+     cli.trainer.fit(
+         model=cli.model,
+         datamodule=cli.datamodule,
+     )
+
+
+ if __name__ == "__main__":
+     config_fn = "./configs/default.yaml"
+
+     # A100 go vroom vroom 🚗💨
+     if torch.cuda.get_device_name(device=0) == 'NVIDIA A100 80GB PCIe':
+         torch.backends.cuda.matmul.allow_tf32 = True
+         print('Superfastmode! 🚀')
+     else:
+         torch.backends.cuda.matmul.allow_tf32 = False
+     cli_main(config_fn)
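
The training entry point wraps LightningCLI with run=False, so launching main.py reads ./configs/default.yaml (not included in this upload) and calls trainer.fit directly. The Lightning module can also be exercised on its own, as in the sketch below; the reduced widths are illustrative, the coordinate order is an assumption, and the datamodules package and positional_encoding module that main.py and model.py import are assumed to be present even though they are not part of this upload:

import torch
from main import SatCLIPLightningModule

# Small, illustrative configuration (the defaults use a larger ViT and the "grid" location encoding).
module = SatCLIPLightningModule(embed_dim=128, vision_layers=2, vision_width=128, capacity=64)

batch = {
    "image": torch.randn(4, 4, 256, 256),  # (B, C, H, W), four Sentinel-2-style channels
    "point": torch.rand(4, 2) * 90,        # dummy coordinates, assumed (lon, lat)
}
loss = module.common_step(batch, batch_idx=0)
print(float(loss))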
model.py ADDED
@@ -0,0 +1,400 @@
+ from collections import OrderedDict
+ from typing import Tuple, Union, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import math
+
+ import timm
+ import torchgeo.models
+ from torchgeo.models import ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights
+ from location_encoder import get_positional_encoding, get_neural_network, LocationEncoder
+ from datamodules.s2geo_dataset import S2Geo
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1):
+         super().__init__()
+
+         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.relu1 = nn.ReLU(inplace=True)
+
+         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.relu2 = nn.ReLU(inplace=True)
+
+         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu3 = nn.ReLU(inplace=True)
+
+         self.downsample = None
+         self.stride = stride
+
+         if stride > 1 or inplanes != planes * Bottleneck.expansion:
+             # the downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+             self.downsample = nn.Sequential(OrderedDict([
+                 ("-1", nn.AvgPool2d(stride)),
+                 ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                 ("1", nn.BatchNorm2d(planes * self.expansion))
+             ]))
+
+     def forward(self, x: torch.Tensor):
+         identity = x
+
+         out = self.relu1(self.bn1(self.conv1(x)))
+         out = self.relu2(self.bn2(self.conv2(out)))
+         out = self.avgpool(out)
+         out = self.bn3(self.conv3(out))
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu3(out)
+         return out
+
+
+ class AttentionPool2d(nn.Module):
+     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+         super().__init__()
+         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+         self.num_heads = num_heads
+
+     def forward(self, x):
+         x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
+         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+         x, _ = F.multi_head_attention_forward(
+             query=x[:1], key=x, value=x,
+             embed_dim_to_check=x.shape[-1],
+             num_heads=self.num_heads,
+             q_proj_weight=self.q_proj.weight,
+             k_proj_weight=self.k_proj.weight,
+             v_proj_weight=self.v_proj.weight,
+             in_proj_weight=None,
+             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+             bias_k=None,
+             bias_v=None,
+             add_zero_attn=False,
+             dropout_p=0,
+             out_proj_weight=self.c_proj.weight,
+             out_proj_bias=self.c_proj.bias,
+             use_separate_proj_weight=True,
+             training=self.training,
+             need_weights=False
+         )
+         return x.squeeze(0)
+
+
+ class ModifiedResNet(nn.Module):
+     """
+     A ResNet class that is similar to torchvision's but contains the following changes:
+     - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+     - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+     - The final pooling layer is a QKV attention instead of an average pool
+     """
+
+     def __init__(self, layers, output_dim, heads, input_resolution=224, width=64, in_channels=3):
+         super().__init__()
+         self.output_dim = output_dim
+         self.input_resolution = input_resolution
+
+         # the 3-layer stem
+         self.conv1 = nn.Conv2d(in_channels, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(width // 2)
+         self.relu1 = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(width // 2)
+         self.relu2 = nn.ReLU(inplace=True)
+         self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(width)
+         self.relu3 = nn.ReLU(inplace=True)
+         self.avgpool = nn.AvgPool2d(2)
+
+         # residual layers
+         self._inplanes = width  # this is a *mutable* variable used during construction
+         self.layer1 = self._make_layer(width, layers[0])
+         self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+         self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+         self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+         embed_dim = width * 32  # the ResNet feature dimension
+         self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+     def _make_layer(self, planes, blocks, stride=1):
+         layers = [Bottleneck(self._inplanes, planes, stride)]
+
+         self._inplanes = planes * Bottleneck.expansion
+         for _ in range(1, blocks):
+             layers.append(Bottleneck(self._inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         def stem(x):
+             x = self.relu1(self.bn1(self.conv1(x)))
+             x = self.relu2(self.bn2(self.conv2(x)))
+             x = self.relu3(self.bn3(self.conv3(x)))
+             x = self.avgpool(x)
+             return x
+
+         x = x.type(self.conv1.weight.dtype)
+         x = stem(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.attnpool(x)
+
+         return x
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm to handle fp16."""
+
+     def forward(self, x: torch.Tensor):
+         orig_type = x.dtype
+         ret = super().forward(x.type(torch.float32))
+         return ret.type(orig_type)
+
+
+ class QuickGELU(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+         super().__init__()
+
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ln_1 = LayerNorm(d_model)
+         self.mlp = nn.Sequential(OrderedDict([
+             ("c_fc", nn.Linear(d_model, d_model * 4)),
+             ("gelu", QuickGELU()),
+             ("c_proj", nn.Linear(d_model * 4, d_model))
+         ]))
+         self.ln_2 = LayerNorm(d_model)
+         self.attn_mask = attn_mask
+
+     def attention(self, x: torch.Tensor):
+         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attention(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+     def forward(self, x: torch.Tensor):
+         return self.resblocks(x)
+
+
+ class VisionTransformer(nn.Module):
+     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, in_channels: int, output_dim: int):
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.output_dim = output_dim
+         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+         scale = width ** -0.5
+         self.class_embedding = nn.Parameter(scale * torch.randn(width))
+         self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+         self.ln_pre = LayerNorm(width)
+
+         self.transformer = Transformer(width, layers, heads)
+
+         self.ln_post = LayerNorm(width)
+         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+     def forward(self, x: torch.Tensor):
+         x = self.conv1(x)  # shape = [*, width, grid, grid]
+         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+         x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+         x = x + self.positional_embedding.to(x.dtype)
+         x = self.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         x = self.ln_post(x[:, 0, :])
+
+         if self.proj is not None:
+             x = x @ self.proj
+
+         return x
+
+ class SatCLIP(nn.Module):
+     def __init__(self,
+                  embed_dim: int,
+                  # vision
+                  image_resolution: int,
+                  vision_layers: Union[Tuple[int, int, int, int], int, str],
+                  vision_width: int,
+                  vision_patch_size: int,
+                  in_channels: int,
+                  # location
+                  le_type: str,
+                  pe_type: str,
+                  frequency_num: int,
+                  max_radius: int,
+                  min_radius: int,
+                  harmonics_calculation: str,
+                  legendre_polys: int = 10,
+                  sh_embedding_dims: int = 16,
+                  ffn: bool = True,
+                  num_hidden_layers: int = 2,
+                  capacity: int = 256,
+                  *args,
+                  **kwargs
+                  ):
+         super().__init__()
+
+         if isinstance(vision_layers, (tuple, list)):
+             print('using modified resnet')
+             vision_heads = vision_width * 32 // 64
+             self.visual = ModifiedResNet(
+                 layers=vision_layers,
+                 output_dim=embed_dim,
+                 heads=vision_heads,
+                 input_resolution=image_resolution,
+                 width=vision_width,
+                 in_channels=in_channels
+             )
+
+         elif vision_layers == 'moco_resnet18':
+             print('using pretrained moco resnet18')
+             weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
+             in_chans = weights.meta["in_chans"]
+             self.visual = timm.create_model("resnet18", in_chans=in_chans, num_classes=embed_dim)
+             self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+             self.visual.requires_grad_(False)
+             self.visual.fc.requires_grad_(True)
+
+         elif vision_layers == 'moco_resnet50':
+             print('using pretrained moco resnet50')
+             weights = ResNet50_Weights.SENTINEL2_ALL_MOCO
+             in_chans = weights.meta["in_chans"]
+             self.visual = timm.create_model("resnet50", in_chans=in_chans, num_classes=embed_dim)
+             self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+             self.visual.requires_grad_(False)
+             self.visual.fc.requires_grad_(True)
+
+         elif vision_layers == 'moco_vit16':
+             print('using pretrained moco vit16')
+             weights = ViTSmall16_Weights.SENTINEL2_ALL_MOCO
+             in_chans = weights.meta["in_chans"]
+             self.visual = timm.create_model("vit_small_patch16_224", in_chans=in_chans, num_classes=embed_dim)
+             self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+             self.visual.requires_grad_(False)
+             self.visual.head.requires_grad_(True)
+
+         else:
+             print('using vision transformer')
+             vision_heads = vision_width // 64
+             self.visual = VisionTransformer(
+                 input_resolution=image_resolution,
+                 patch_size=vision_patch_size,
+                 width=vision_width,
+                 layers=vision_layers,
+                 heads=vision_heads,
+                 output_dim=embed_dim,
+                 in_channels=in_channels
+             )
+
+         self.posenc = get_positional_encoding(name=le_type, harmonics_calculation=harmonics_calculation, legendre_polys=legendre_polys, min_radius=min_radius, max_radius=max_radius, frequency_num=frequency_num).double()
+         self.nnet = get_neural_network(name=pe_type, input_dim=self.posenc.embedding_dim, num_classes=embed_dim, dim_hidden=capacity, num_layers=num_hidden_layers).double()
+         self.location = LocationEncoder(self.posenc,
+                                         self.nnet
+                                         ).double()
+
+         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+         self.initialize_parameters()
+
+     def initialize_parameters(self):
+         if isinstance(self.visual, ModifiedResNet):
+             if self.visual.attnpool is not None:
+                 std = self.visual.attnpool.c_proj.in_features ** -0.5
+                 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                 nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                 nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                 nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+             for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+                 for name, param in resnet_block.named_parameters():
+                     if name.endswith("bn3.weight"):
+                         nn.init.zeros_(param)
+
+     @property
+     def dtype(self):
+         if isinstance(self.visual, timm.models.vision_transformer.VisionTransformer):
+             return self.visual.patch_embed.proj.weight.dtype
+         else:
+             return self.visual.conv1.weight.dtype
+
+     def encode_image(self, image):
+         return self.visual(image.type(self.dtype))
+
+     def encode_location(self, coords):
+         return self.location(coords.double())
+
+     def forward(self, image, coords):
+
+         image_features = self.encode_image(image)
+         location_features = self.encode_location(coords).float()
+         # normalized features
+         image_features = image_features / image_features.norm(dim=1, keepdim=True)
+         location_features = location_features / location_features.norm(dim=1, keepdim=True)
+
+         # cosine similarity as logits
+         logit_scale = self.logit_scale.exp()
+         logits_per_image = logit_scale * image_features @ location_features.t()
+         logits_per_location = logits_per_image.t()
+
+         # shape = [global_batch_size, global_batch_size]
+         return logits_per_image, logits_per_location
+
+ def convert_weights(model: nn.Module):
+     """Convert applicable model parameters to fp16"""
+
+     def _convert_weights_to_fp16(l):
+         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+             l.weight.data = l.weight.data.half()
+             if l.bias is not None:
+                 l.bias.data = l.bias.data.half()
+
+         if isinstance(l, nn.MultiheadAttention):
+             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+                 tensor = getattr(l, attr)
+                 if tensor is not None:
+                     tensor.data = tensor.data.half()
+
+         for name in ["text_projection", "proj"]:
+             if hasattr(l, name):
+                 attr = getattr(l, name)
+                 if attr is not None:
+                     attr.data = attr.data.half()
+
+     model.apply(_convert_weights_to_fp16)
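
Finally, a hedged end-to-end sketch of SatCLIP.forward on random tensors, using a deliberately small ViT configuration; as above, the positional_encoding and datamodules imports at the top of model.py are assumed to resolve, and the coordinate order is an assumption:

import torch
from model import SatCLIP

model = SatCLIP(
    embed_dim=128,
    image_resolution=256, vision_layers=2, vision_width=128, vision_patch_size=32,
    in_channels=4,
    le_type="sphericalharmonics", pe_type="siren",
    frequency_num=16, max_radius=260, min_radius=1,
    harmonics_calculation="analytic", legendre_polys=10,
)

images = torch.randn(4, 4, 256, 256)  # (B, C, H, W)
coords = torch.rand(4, 2) * 90        # dummy coordinates, assumed (lon, lat)
logits_per_image, logits_per_location = model(images, coords)
print(logits_per_image.shape)         # torch.Size([4, 4])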