Vishu26 committed
Commit fa28aab · 1 Parent(s): 2c3470d
.gitattributes CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 data/class_70b.npy filter=lfs diff=lfs merge=lfs -text
 data/order_70b.npy filter=lfs diff=lfs merge=lfs -text
 data/species_70b.npy filter=lfs diff=lfs merge=lfs -text
+data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
+model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -13,6 +13,8 @@ order_list = list(order[()].keys())
 #genus_list = list(genus[()].keys())
 #family_list = list(family[()].keys())
 
+pos_embed = np.load("data/pos_embed.npy", allow_pickle=True)
+
 def update_fn(val):
     if val=="Class":
         return gr.Dropdown(label="Name", choices=class_list, interactive=True)
@@ -25,6 +27,20 @@ def update_fn(val):
     elif val=="Species":
         return gr.Dropdown(label="Name", choices=species_list, interactive=True)
 
+def pred_fn(taxon, name):
+    if taxon=="Class":
+        text_embeds = clas[()][name]
+    elif taxon=="Order":
+        text_embeds = order[()][name]
+    elif taxon=="Family":
+        text_embeds = family[()][name]
+    elif taxon=="Genus":
+        text_embeds = genus[()][name]
+    elif taxon=="Species":
+        text_embeds = species[()][name]
+
+
+
 with gr.Blocks() as demo:
     gr.Markdown(
         """
@@ -39,5 +55,10 @@ with gr.Blocks() as demo:
     with gr.Row():
         submit_button = gr.Button("Run Model")
 
+    with gr.Row():
+        pred = gr.Image(label="Predicted Heatmap", visible=False)
+
+    submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])
 
+
 demo.launch()
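
As committed, pred_fn only selects text_embeds and returns nothing, so submit_button.click leaves the hidden gr.Image unchanged; note also that app.py loads data/pos_embed.npy while the file tracked in this commit is data/pos_embeds_model.npy. A minimal sketch of the missing scoring/return step, assuming pos_embed holds one embedding per map cell; the similarity step, the H x W grid shape, and the visibility flip are assumptions, not part of the commit:

# Hypothetical completion of pred_fn; the pos_embed layout (H*W, D) and the
# grid shape H, W are assumptions about the demo's data.
def pred_fn(taxon, name):
    if taxon == "Class":
        text_embeds = clas[()][name]
    elif taxon == "Order":
        text_embeds = order[()][name]
    elif taxon == "Family":
        text_embeds = family[()][name]
    elif taxon == "Genus":
        text_embeds = genus[()][name]
    elif taxon == "Species":
        text_embeds = species[()][name]
    # Assumed scoring: dot product of each location embedding with the
    # selected text embedding, min-max normalized into a displayable map.
    scores = (pos_embed @ text_embeds).reshape(H, W)
    heatmap = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    # Return a component instance to update value and visibility, matching
    # the style update_fn already uses for gr.Dropdown.
    return gr.Image(value=heatmap, visible=True)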
data/pos_embeds_model.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e010369a67f1d2946dd787494a65f88c1ed79d1cc6d4a5be3f5ac98568492630
+size 829440128
model/demo_model.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5d592a3a09086658aeac9b51574f11977962c0f2d5703e0225c3a236be4592d
+size 76024944
requirements.txt ADDED
@@ -0,0 +1,4 @@
+numpy==1.23.4
+torch==2.0.1
+rasterio==1.3.8
+einops==0.6.1
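
One gap worth flagging: sfno_encoder.py imports torch_harmonics, which this file does not list (gradio is supplied by the Space's SDK setting). An entry along these lines would presumably be needed; the exact distribution name and version are assumptions:

torch-harmonics  # assumed: needed by sfno_encoder.py; version unpinned here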
sfno_encoder.py ADDED
@@ -0,0 +1,479 @@
+import torch
+import torch.nn as nn
+
+from torch_harmonics import *
+
+from torch_harmonics.examples.sfno.models.layers import *
+
+from functools import partial
+
+from einops import repeat
+
+import numpy as np
+
+class SpectralFilterLayer(nn.Module):
+    """
+    Fourier layer. Contains the convolution part of the FNO/SFNO
+    """
+
+    def __init__(
+        self,
+        forward_transform,
+        inverse_transform,
+        embed_dim,
+        filter_type = 'non-linear',
+        operator_type = 'diagonal',
+        sparsity_threshold = 0.0,
+        use_complex_kernels = True,
+        hidden_size_factor = 2,
+        factorization = None,
+        separable = False,
+        rank = 1e-2,
+        complex_activation = 'real',
+        spectral_layers = 1,
+        drop_rate = 0):
+        super(SpectralFilterLayer, self).__init__()
+
+        if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
+            self.filter = SpectralAttentionS2(forward_transform,
+                                              inverse_transform,
+                                              embed_dim,
+                                              operator_type = operator_type,
+                                              sparsity_threshold = sparsity_threshold,
+                                              hidden_size_factor = hidden_size_factor,
+                                              complex_activation = complex_activation,
+                                              spectral_layers = spectral_layers,
+                                              drop_rate = drop_rate,
+                                              bias = False)
+
+        elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
+            self.filter = SpectralAttention2d(forward_transform,
+                                              inverse_transform,
+                                              embed_dim,
+                                              sparsity_threshold = sparsity_threshold,
+                                              use_complex_kernels = use_complex_kernels,
+                                              hidden_size_factor = hidden_size_factor,
+                                              complex_activation = complex_activation,
+                                              spectral_layers = spectral_layers,
+                                              drop_rate = drop_rate,
+                                              bias = False)
+
+        elif filter_type == 'linear':
+            self.filter = SpectralConvS2(forward_transform,
+                                         inverse_transform,
+                                         embed_dim,
+                                         embed_dim,
+                                         operator_type = operator_type,
+                                         rank = rank,
+                                         factorization = factorization,
+                                         separable = separable,
+                                         bias = True)
+
+        else:
+            raise(NotImplementedError)
+
+    def forward(self, x):
+        return self.filter(x)
+
+class SphericalFourierNeuralOperatorBlock(nn.Module):
+    """
+    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
+    """
+    def __init__(
+        self,
+        forward_transform,
+        inverse_transform,
+        embed_dim,
+        filter_type = 'non-linear',
+        operator_type = 'diagonal',
+        mlp_ratio = 2.,
+        drop_rate = 0.,
+        drop_path = 0.,
+        act_layer = nn.GELU,
+        norm_layer = (nn.LayerNorm, nn.LayerNorm),
+        sparsity_threshold = 0.0,
+        use_complex_kernels = True,
+        factorization = None,
+        separable = False,
+        rank = 128,
+        inner_skip = 'linear',
+        outer_skip = None, # None, nn.linear or nn.Identity
+        concat_skip = False,
+        use_mlp = True,
+        complex_activation = 'real',
+        spectral_layers = 3):
+        super(SphericalFourierNeuralOperatorBlock, self).__init__()
+
+        # norm layer
+        self.norm0 = norm_layer[0]() #((h,w))
+
+        # convolution layer
+        self.filter = SpectralFilterLayer(forward_transform,
+                                          inverse_transform,
+                                          embed_dim,
+                                          filter_type,
+                                          operator_type = operator_type,
+                                          sparsity_threshold = sparsity_threshold,
+                                          use_complex_kernels = use_complex_kernels,
+                                          hidden_size_factor = mlp_ratio,
+                                          factorization = factorization,
+                                          separable = separable,
+                                          rank = rank,
+                                          complex_activation = complex_activation,
+                                          spectral_layers = spectral_layers,
+                                          drop_rate = drop_rate)
+
+        if inner_skip == 'linear':
+            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
+        elif inner_skip == 'identity':
+            self.inner_skip = nn.Identity()
+
+        self.concat_skip = concat_skip
+
+        if concat_skip and inner_skip is not None:
+            self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
+
+        if filter_type == 'linear' or filter_type == 'local':
+            self.act_layer = act_layer()
+
+        # dropout
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        # norm layer
+        self.norm1 = norm_layer[1]() #((h,w))
+
+        if use_mlp == True:
+            mlp_hidden_dim = int(embed_dim * mlp_ratio)
+            self.mlp = MLP(in_features = embed_dim,
+                           hidden_features = mlp_hidden_dim,
+                           act_layer = act_layer,
+                           drop_rate = drop_rate,
+                           checkpointing = False)
+
+        if outer_skip == 'linear':
+            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
+        elif outer_skip == 'identity':
+            self.outer_skip = nn.Identity()
+
+        if concat_skip and outer_skip is not None:
+            self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
+
+    def forward(self, x):
+
+        x = self.norm0(x)
+
+        x, residual = self.filter(x)
+
+        if hasattr(self, 'inner_skip'):
+            if self.concat_skip:
+                x = torch.cat((x, self.inner_skip(residual)), dim=1)
+                x = self.inner_skip_conv(x)
+            else:
+                x = x + self.inner_skip(residual)
+
+        if hasattr(self, 'act_layer'):
+            x = self.act_layer(x)
+
+        x = self.norm1(x)
+
+        if hasattr(self, 'mlp'):
+            x = self.mlp(x)
+
+        x = self.drop_path(x)
+
+        if hasattr(self, 'outer_skip'):
+            if self.concat_skip:
+                x = torch.cat((x, self.outer_skip(residual)), dim=1)
+                x = self.outer_skip_conv(x)
+            else:
+                x = x + self.outer_skip(residual)
+
+        return x
+
+class SphericalFourierNeuralOperatorNet(nn.Module):
+    """
+    SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
+    both linear and non-linear variants.
+
+    Parameters
+    ----------
+    filter_type : str, optional
+        Type of filter to use ('linear', 'non-linear'), by default "linear"
+    spectral_transform : str, optional
+        Type of spectral transformation to use, by default "sht"
+    operator_type : str, optional
+        Type of operator to use ('vector', 'diagonal'), by default "vector"
+    img_shape : tuple, optional
+        Shape of the input channels, by default (128, 256)
+    scale_factor : int, optional
+        Scale factor to use, by default 3
+    in_chans : int, optional
+        Number of input channels, by default 3
+    out_chans : int, optional
+        Number of output channels, by default 3
+    embed_dim : int, optional
+        Dimension of the embeddings, by default 256
+    num_layers : int, optional
+        Number of layers in the network, by default 4
+    activation_function : str, optional
+        Activation function to use, by default "gelu"
+    encoder_layers : int, optional
+        Number of layers in the encoder, by default 1
+    use_mlp : int, optional
+        Whether to use MLP, by default True
+    mlp_ratio : int, optional
+        Ratio of MLP to use, by default 2.0
+    drop_rate : float, optional
+        Dropout rate, by default 0.0
+    drop_path_rate : float, optional
+        Dropout path rate, by default 0.0
+    sparsity_threshold : float, optional
+        Threshold for sparsity, by default 0.0
+    normalization_layer : str, optional
+        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
+    hard_thresholding_fraction : float, optional
+        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
+    use_complex_kernels : bool, optional
+        Whether to use complex kernels, by default True
+    big_skip : bool, optional
+        Whether to add a single large skip connection, by default True
+    rank : float, optional
+        Rank of the approximation, by default 1.0
+    factorization : Any, optional
+        Type of factorization to use, by default None
+    separable : bool, optional
+        Whether to use separable convolutions, by default False
+    rank : (int, Tuple[int]), optional
+        If a factorization is used, which rank to use. Argument is passed to tensorly
+    complex_activation : str, optional
+        Type of complex activation function to use, by default "real"
+    spectral_layers : int, optional
+        Number of spectral layers, by default 3
+    pos_embed : bool, optional
+        Whether to use positional embedding, by default True
+
+    Example:
+    --------
+    >>> model = SphericalFourierNeuralOperatorNet(
+    ...         img_shape=(128, 256),
+    ...         scale_factor=4,
+    ...         in_chans=2,
+    ...         out_chans=2,
+    ...         embed_dim=16,
+    ...         num_layers=2,
+    ...         encoder_layers=1,
+    ...         num_blocks=4,
+    ...         spectral_layers=2,
+    ...         use_mlp=True,)
+    >>> model(torch.randn(1, 2, 128, 256)).shape
+    torch.Size([1, 2, 128, 256])
+    """
+
+    def __init__(
+        self,
+        filter_type = 'linear',
+        spectral_transform = 'sht',
+        operator_type = 'vector',
+        img_size = (128, 256),
+        scale_factor = 4,
+        in_chans = 3,
+        out_chans = 3,
+        embed_dim = 256,
+        num_layers = 4,
+        activation_function = 'gelu',
+        encoder_layers = 1,
+        use_mlp = True,
+        mlp_ratio = 2.,
+        drop_rate = 0.,
+        drop_path_rate = 0.,
+        sparsity_threshold = 0.0,
+        normalization_layer = 'instance_norm',
+        hard_thresholding_fraction = 1.0,
+        use_complex_kernels = True,
+        big_skip = False,
+        factorization = None,
+        separable = False,
+        rank = 128,
+        complex_activation = 'real',
+        spectral_layers = 2,
+        pos_embed = True
+        ):
+
+        super(SphericalFourierNeuralOperatorNet, self).__init__()
+
+        self.filter_type = filter_type
+        self.spectral_transform = spectral_transform
+        self.operator_type = operator_type
+        self.img_size = img_size
+        self.scale_factor = scale_factor
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.embed_dim = self.num_features = embed_dim
+        self.pos_embed_dim = self.embed_dim
+        self.num_layers = num_layers
+        self.hard_thresholding_fraction = hard_thresholding_fraction
+        self.normalization_layer = normalization_layer
+        self.use_mlp = use_mlp
+        self.encoder_layers = encoder_layers
+        self.big_skip = big_skip
+        self.factorization = factorization
+        self.separable = separable,
+        self.rank = rank
+        self.complex_activation = complex_activation
+        self.spectral_layers = spectral_layers
+
+        # activation function
+        if activation_function == 'relu':
+            self.activation_function = nn.ReLU
+        elif activation_function == 'gelu':
+            self.activation_function = nn.GELU
+        else:
+            raise ValueError(f"Unknown activation function {activation_function}")
+
+        # compute downsampled image size
+        self.h = self.img_size[0] // scale_factor
+        self.w = self.img_size[1] // scale_factor
+
+        # dropout
+        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
+
+        # pick norm layer
+        if self.normalization_layer == "layer_norm":
+            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
+            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
+        elif self.normalization_layer == "instance_norm":
+            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
+            norm_layer1 = norm_layer0
+        elif self.normalization_layer == "none":
+            norm_layer0 = nn.Identity
+            norm_layer1 = norm_layer0
+        else:
+            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
+
+        if pos_embed:
+            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
+            #self.pos_embed = posemb_sincos_2d(900, 1800, 128)
+            pass
+            #x = torch.linspace(-np.pi, np.pi, 900)
+            #y = torch.linspace(-np.pi, np.pi, 1800)
+            #x, y = torch.meshgrid(x, y)
+            #self.pos_embed = torch.stack((torch.sin(x), torch.sin(y), torch.cos(x), torch.cos(y)), dim=0).unsqueeze(0).cuda()
+            #self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
+            #self.pos_direct = nn.Conv2d(4, self.embed_dim, 1, bias=False)
+        else:
+            self.pos_embed = None
+
+        # encoder
+        """encoder_hidden_dim = self.embed_dim
+        current_dim = self.in_chans
+        encoder_modules = []
+        for i in range(self.encoder_layers):
+            encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
+            encoder_modules.append(self.activation_function())
+            current_dim = encoder_hidden_dim
+        encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
+        self.encoder = nn.Sequential(*encoder_modules)"""
+
+        # prepare the spectral transform
+        if self.spectral_transform == 'sht':
+
+            modes_lat = int(self.h * self.hard_thresholding_fraction)
+            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
+
+            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
+            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
+            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
+            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
+
+        elif self.spectral_transform == 'fft':
+
+            modes_lat = int(self.h * self.hard_thresholding_fraction)
+            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
+
+            self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
+            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
+            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
+            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
+
+        else:
+            raise(ValueError('Unknown spectral transform'))
+
+        self.blocks = nn.ModuleList([])
+        for i in range(self.num_layers):
+
+            first_layer = i == 0
+            last_layer = i == self.num_layers-1
+
+            forward_transform = self.trans_down if first_layer else self.trans
+            inverse_transform = self.itrans_up if last_layer else self.itrans
+
+            inner_skip = 'linear'
+            outer_skip = 'identity'
+
+            if first_layer:
+                norm_layer = (norm_layer0, norm_layer1)
+            elif last_layer:
+                norm_layer = (norm_layer1, norm_layer0)
+            else:
+                norm_layer = (norm_layer1, norm_layer1)
+
+            block = SphericalFourierNeuralOperatorBlock(forward_transform,
+                                                        inverse_transform,
+                                                        self.embed_dim,
+                                                        filter_type = filter_type,
+                                                        operator_type = self.operator_type,
+                                                        mlp_ratio = mlp_ratio,
+                                                        drop_rate = drop_rate,
+                                                        drop_path = dpr[i],
+                                                        act_layer = self.activation_function,
+                                                        norm_layer = norm_layer,
+                                                        sparsity_threshold = sparsity_threshold,
+                                                        use_complex_kernels = use_complex_kernels,
+                                                        inner_skip = inner_skip,
+                                                        outer_skip = outer_skip,
+                                                        use_mlp = use_mlp,
+                                                        factorization = self.factorization,
+                                                        separable = self.separable,
+                                                        rank = self.rank,
+                                                        complex_activation = self.complex_activation,
+                                                        spectral_layers = self.spectral_layers)
+
+            self.blocks.append(block)
+
+        # trunc_normal_(self.pos_embed, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
+            trunc_normal_(m.weight, std=.02)
+            #nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward_features(self, x):
+
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        return x
+
+    def forward(self, x):
+
+        #if self.big_skip:
+            #residual = x
+
+        #x = self.encoder(x)
+
+        #x = x + self.pos_embed
+        x = self.pos_embed
+
+        x = self.forward_features(x)
+
+        return x
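
For orientation, a small smoke-test sketch of the encoder as committed (not part of the repo). It relies on the torch_harmonics example layers imported above; note that forward ignores its argument and runs the SFNO blocks directly on the learned pos_embed, so the placeholder input below only satisfies the signature, and nothing here loads model/demo_model.pt:

import torch
from sfno_encoder import SphericalFourierNeuralOperatorNet

# Defaults mirror the constructor above: filter_type='linear',
# img_size=(128, 256), embed_dim=256, num_layers=4, spectral_transform='sht'.
model = SphericalFourierNeuralOperatorNet()
model.eval()

with torch.no_grad():
    # forward(x) discards x and processes self.pos_embed, so any tensor works.
    out = model(torch.zeros(1, 3, 128, 256))

# The blocks keep embed_dim channels and the last block returns to the full
# grid via itrans_up, so the expected shape is (1, 256, 128, 256).
print(out.shape)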