Vishu26 committed
Commit 9ae1458 · Parent: 2bebf28
Files changed (2):
  1. app.py +6 -0
  2. sfno_encoder.py +0 -479
app.py CHANGED
@@ -76,6 +76,9 @@ def update_fn(val):
     elif val=="Species":
         return gr.Dropdown(label="Name", choices=species_list, interactive=True)
 
+def text_fn(taxon, name):
+    return taxon + ": " + name
+
 def pred_fn(taxon, name):
     if taxon=="Class":
         text_embeds = clas[()][name]
@@ -102,14 +105,17 @@ with gr.Blocks() as demo:
     with gr.Row():
         inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Family", "Genus", "Species"])
         out = gr.Dropdown(label="Name", interactive=True)
+        text = gr.Textbox(label="Text", visible=True, interactive=True)
         inp.change(update_fn, inp, out)
 
     with gr.Row():
+        check_button = gr.Button("Check")
         submit_button = gr.Button("Run Model")
 
     with gr.Row():
         pred = gr.Image(label="Predicted Heatmap", visible=False)
 
+    check_button.click(text_fn, inputs=[inp, out], outputs=[text])
     submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])
 
 demo.launch()
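The net effect in app.py: text_fn echoes the current dropdown selections, and the new "Check" button routes them into the Textbox so a user can preview the (taxon, name) pair before running the model. Below is a minimal, self-contained sketch of the same wiring; the species list here is a hypothetical stand-in for the app's real taxonomy data.

import gradio as gr

# Hypothetical stand-in; the real app loads full taxonomy name lists.
species_list = ["Bubo bubo", "Falco peregrinus"]

def text_fn(taxon, name):
    # Same behavior as the commit: echo the selection as "Taxon: Name".
    return taxon + ": " + name

with gr.Blocks() as demo:
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Family", "Genus", "Species"])
        out = gr.Dropdown(label="Name", choices=species_list, interactive=True)
        text = gr.Textbox(label="Text", visible=True, interactive=True)
    with gr.Row():
        check_button = gr.Button("Check")
    # Clicking "Check" writes text_fn's result into the Textbox.
    check_button.click(text_fn, inputs=[inp, out], outputs=[text])

demo.launch()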
sfno_encoder.py DELETED
@@ -1,479 +0,0 @@
-import torch
-import torch.nn as nn
-
-from torch_harmonics import *
-
-from torch_harmonics.examples.sfno.models.layers import *
-
-from functools import partial
-
-from einops import repeat
-
-import numpy as np
-
-class SpectralFilterLayer(nn.Module):
-    """
-    Fourier layer. Contains the convolution part of the FNO/SFNO.
-    """
-
-    def __init__(
-        self,
-        forward_transform,
-        inverse_transform,
-        embed_dim,
-        filter_type = 'non-linear',
-        operator_type = 'diagonal',
-        sparsity_threshold = 0.0,
-        use_complex_kernels = True,
-        hidden_size_factor = 2,
-        factorization = None,
-        separable = False,
-        rank = 1e-2,
-        complex_activation = 'real',
-        spectral_layers = 1,
-        drop_rate = 0):
-        super(SpectralFilterLayer, self).__init__()
-
-        if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
-            self.filter = SpectralAttentionS2(forward_transform,
-                                              inverse_transform,
-                                              embed_dim,
-                                              operator_type = operator_type,
-                                              sparsity_threshold = sparsity_threshold,
-                                              hidden_size_factor = hidden_size_factor,
-                                              complex_activation = complex_activation,
-                                              spectral_layers = spectral_layers,
-                                              drop_rate = drop_rate,
-                                              bias = False)
-
-        elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
-            self.filter = SpectralAttention2d(forward_transform,
-                                              inverse_transform,
-                                              embed_dim,
-                                              sparsity_threshold = sparsity_threshold,
-                                              use_complex_kernels = use_complex_kernels,
-                                              hidden_size_factor = hidden_size_factor,
-                                              complex_activation = complex_activation,
-                                              spectral_layers = spectral_layers,
-                                              drop_rate = drop_rate,
-                                              bias = False)
-
-        elif filter_type == 'linear':
-            self.filter = SpectralConvS2(forward_transform,
-                                         inverse_transform,
-                                         embed_dim,
-                                         embed_dim,
-                                         operator_type = operator_type,
-                                         rank = rank,
-                                         factorization = factorization,
-                                         separable = separable,
-                                         bias = True)
-
-        else:
-            raise NotImplementedError
-
-    def forward(self, x):
-        return self.filter(x)
-
-class SphericalFourierNeuralOperatorBlock(nn.Module):
-    """
-    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
-    """
-    def __init__(
-        self,
-        forward_transform,
-        inverse_transform,
-        embed_dim,
-        filter_type = 'non-linear',
-        operator_type = 'diagonal',
-        mlp_ratio = 2.,
-        drop_rate = 0.,
-        drop_path = 0.,
-        act_layer = nn.GELU,
-        norm_layer = (nn.LayerNorm, nn.LayerNorm),
-        sparsity_threshold = 0.0,
-        use_complex_kernels = True,
-        factorization = None,
-        separable = False,
-        rank = 128,
-        inner_skip = 'linear',
-        outer_skip = None,  # None, 'linear' or 'identity'
-        concat_skip = False,
-        use_mlp = True,
-        complex_activation = 'real',
-        spectral_layers = 3):
-        super(SphericalFourierNeuralOperatorBlock, self).__init__()
-
-        # norm layer
-        self.norm0 = norm_layer[0]()  # ((h, w))
-
-        # convolution layer
-        self.filter = SpectralFilterLayer(forward_transform,
-                                          inverse_transform,
-                                          embed_dim,
-                                          filter_type,
-                                          operator_type = operator_type,
-                                          sparsity_threshold = sparsity_threshold,
-                                          use_complex_kernels = use_complex_kernels,
-                                          hidden_size_factor = mlp_ratio,
-                                          factorization = factorization,
-                                          separable = separable,
-                                          rank = rank,
-                                          complex_activation = complex_activation,
-                                          spectral_layers = spectral_layers,
-                                          drop_rate = drop_rate)
-
-        if inner_skip == 'linear':
-            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
-        elif inner_skip == 'identity':
-            self.inner_skip = nn.Identity()
-
-        self.concat_skip = concat_skip
-
-        if concat_skip and inner_skip is not None:
-            self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
-
-        if filter_type == 'linear' or filter_type == 'local':
-            self.act_layer = act_layer()
-
-        # dropout
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
-        # norm layer
-        self.norm1 = norm_layer[1]()  # ((h, w))
-
-        if use_mlp:
-            mlp_hidden_dim = int(embed_dim * mlp_ratio)
-            self.mlp = MLP(in_features = embed_dim,
-                           hidden_features = mlp_hidden_dim,
-                           act_layer = act_layer,
-                           drop_rate = drop_rate,
-                           checkpointing = False)
-
-        if outer_skip == 'linear':
-            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
-        elif outer_skip == 'identity':
-            self.outer_skip = nn.Identity()
-
-        if concat_skip and outer_skip is not None:
-            self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
-
-    def forward(self, x):
-
-        x = self.norm0(x)
-
-        x, residual = self.filter(x)
-
-        if hasattr(self, 'inner_skip'):
-            if self.concat_skip:
-                x = torch.cat((x, self.inner_skip(residual)), dim=1)
-                x = self.inner_skip_conv(x)
-            else:
-                x = x + self.inner_skip(residual)
-
-        if hasattr(self, 'act_layer'):
-            x = self.act_layer(x)
-
-        x = self.norm1(x)
-
-        if hasattr(self, 'mlp'):
-            x = self.mlp(x)
-
-        x = self.drop_path(x)
-
-        if hasattr(self, 'outer_skip'):
-            if self.concat_skip:
-                x = torch.cat((x, self.outer_skip(residual)), dim=1)
-                x = self.outer_skip_conv(x)
-            else:
-                x = x + self.outer_skip(residual)
-
-        return x
-
-class SphericalFourierNeuralOperatorNet(nn.Module):
-    """
-    SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
-    both linear and non-linear variants.
-
-    Parameters
-    ----------
-    filter_type : str, optional
-        Type of filter to use ('linear', 'non-linear'), by default "linear"
-    spectral_transform : str, optional
-        Type of spectral transformation to use ('sht', 'fft'), by default "sht"
-    operator_type : str, optional
-        Type of operator to use ('vector', 'diagonal'), by default "vector"
-    img_size : tuple, optional
-        Spatial shape of the input grid (nlat, nlon), by default (128, 256)
-    scale_factor : int, optional
-        Scale factor to use, by default 4
-    in_chans : int, optional
-        Number of input channels, by default 3
-    out_chans : int, optional
-        Number of output channels, by default 3
-    embed_dim : int, optional
-        Dimension of the embeddings, by default 256
-    num_layers : int, optional
-        Number of layers in the network, by default 4
-    activation_function : str, optional
-        Activation function to use, by default "gelu"
-    encoder_layers : int, optional
-        Number of layers in the encoder, by default 1
-    use_mlp : bool, optional
-        Whether to use an MLP in each block, by default True
-    mlp_ratio : float, optional
-        Ratio of MLP hidden dimension to embedding dimension, by default 2.0
-    drop_rate : float, optional
-        Dropout rate, by default 0.0
-    drop_path_rate : float, optional
-        Dropout path rate, by default 0.0
-    sparsity_threshold : float, optional
-        Threshold for sparsity, by default 0.0
-    normalization_layer : str, optional
-        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
-    hard_thresholding_fraction : float, optional
-        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
-    use_complex_kernels : bool, optional
-        Whether to use complex kernels, by default True
-    big_skip : bool, optional
-        Whether to add a single large skip connection, by default False
-    factorization : Any, optional
-        Type of factorization to use, by default None
-    separable : bool, optional
-        Whether to use separable convolutions, by default False
-    rank : (int, Tuple[int]), optional
-        If a factorization is used, which rank to use (passed to tensorly), by default 128
-    complex_activation : str, optional
-        Type of complex activation function to use, by default "real"
-    spectral_layers : int, optional
-        Number of spectral layers, by default 2
-    pos_embed : bool, optional
-        Whether to use positional embedding, by default True
-
-    Example
-    -------
-    >>> model = SphericalFourierNeuralOperatorNet(
-    ...         img_size=(128, 256),
-    ...         scale_factor=4,
-    ...         in_chans=2,
-    ...         out_chans=2,
-    ...         embed_dim=16,
-    ...         num_layers=2,
-    ...         encoder_layers=1,
-    ...         spectral_layers=2,
-    ...         use_mlp=True)
-    >>> model(torch.randn(1, 2, 128, 256)).shape
-    torch.Size([1, 16, 128, 256])
-    """
-
-    def __init__(
-        self,
-        filter_type = 'linear',
-        spectral_transform = 'sht',
-        operator_type = 'vector',
-        img_size = (128, 256),
-        scale_factor = 4,
-        in_chans = 3,
-        out_chans = 3,
-        embed_dim = 256,
-        num_layers = 4,
-        activation_function = 'gelu',
-        encoder_layers = 1,
-        use_mlp = True,
-        mlp_ratio = 2.,
-        drop_rate = 0.,
-        drop_path_rate = 0.,
-        sparsity_threshold = 0.0,
-        normalization_layer = 'instance_norm',
-        hard_thresholding_fraction = 1.0,
-        use_complex_kernels = True,
-        big_skip = False,
-        factorization = None,
-        separable = False,
-        rank = 128,
-        complex_activation = 'real',
-        spectral_layers = 2,
-        pos_embed = True
-    ):
-
-        super(SphericalFourierNeuralOperatorNet, self).__init__()
-
-        self.filter_type = filter_type
-        self.spectral_transform = spectral_transform
-        self.operator_type = operator_type
-        self.img_size = img_size
-        self.scale_factor = scale_factor
-        self.in_chans = in_chans
-        self.out_chans = out_chans
-        self.embed_dim = self.num_features = embed_dim
-        self.pos_embed_dim = self.embed_dim
-        self.num_layers = num_layers
-        self.hard_thresholding_fraction = hard_thresholding_fraction
-        self.normalization_layer = normalization_layer
-        self.use_mlp = use_mlp
-        self.encoder_layers = encoder_layers
-        self.big_skip = big_skip
-        self.factorization = factorization
-        self.separable = separable
-        self.rank = rank
-        self.complex_activation = complex_activation
-        self.spectral_layers = spectral_layers
-
-        # activation function
-        if activation_function == 'relu':
-            self.activation_function = nn.ReLU
-        elif activation_function == 'gelu':
-            self.activation_function = nn.GELU
-        else:
-            raise ValueError(f"Unknown activation function {activation_function}")
-
-        # compute downsampled image size
-        self.h = self.img_size[0] // scale_factor
-        self.w = self.img_size[1] // scale_factor
-
-        # dropout
-        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
-
-        # pick norm layer
-        if self.normalization_layer == "layer_norm":
-            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
-            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
-        elif self.normalization_layer == "instance_norm":
-            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
-            norm_layer1 = norm_layer0
-        elif self.normalization_layer == "none":
-            norm_layer0 = nn.Identity
-            norm_layer1 = norm_layer0
-        else:
-            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
-
-        if pos_embed:
-            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
-            #self.pos_embed = posemb_sincos_2d(900, 1800, 128)
-            #x = torch.linspace(-np.pi, np.pi, 900)
-            #y = torch.linspace(-np.pi, np.pi, 1800)
-            #x, y = torch.meshgrid(x, y)
-            #self.pos_embed = torch.stack((torch.sin(x), torch.sin(y), torch.cos(x), torch.cos(y)), dim=0).unsqueeze(0).cuda()
-            #self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
-            #self.pos_direct = nn.Conv2d(4, self.embed_dim, 1, bias=False)
-        else:
-            self.pos_embed = None
-
-        # encoder (disabled)
-        """encoder_hidden_dim = self.embed_dim
-        current_dim = self.in_chans
-        encoder_modules = []
-        for i in range(self.encoder_layers):
-            encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
-            encoder_modules.append(self.activation_function())
-            current_dim = encoder_hidden_dim
-        encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
-        self.encoder = nn.Sequential(*encoder_modules)"""
-
-        # prepare the spectral transform
-        if self.spectral_transform == 'sht':
-
-            modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
-
-            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
-            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
-            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
-            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
-
-        elif self.spectral_transform == 'fft':
-
-            modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
-
-            self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
-            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
-            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
-            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
-
-        else:
-            raise ValueError('Unknown spectral transform')
-
-        self.blocks = nn.ModuleList([])
-        for i in range(self.num_layers):
-
-            first_layer = i == 0
-            last_layer = i == self.num_layers - 1
-
-            forward_transform = self.trans_down if first_layer else self.trans
-            inverse_transform = self.itrans_up if last_layer else self.itrans
-
-            inner_skip = 'linear'
-            outer_skip = 'identity'
-
-            if first_layer:
-                norm_layer = (norm_layer0, norm_layer1)
-            elif last_layer:
-                norm_layer = (norm_layer1, norm_layer0)
-            else:
-                norm_layer = (norm_layer1, norm_layer1)
-
-            block = SphericalFourierNeuralOperatorBlock(forward_transform,
-                                                        inverse_transform,
-                                                        self.embed_dim,
-                                                        filter_type = filter_type,
-                                                        operator_type = self.operator_type,
-                                                        mlp_ratio = mlp_ratio,
-                                                        drop_rate = drop_rate,
-                                                        drop_path = dpr[i],
-                                                        act_layer = self.activation_function,
-                                                        norm_layer = norm_layer,
-                                                        sparsity_threshold = sparsity_threshold,
-                                                        use_complex_kernels = use_complex_kernels,
-                                                        inner_skip = inner_skip,
-                                                        outer_skip = outer_skip,
-                                                        use_mlp = use_mlp,
-                                                        factorization = self.factorization,
-                                                        separable = self.separable,
-                                                        rank = self.rank,
-                                                        complex_activation = self.complex_activation,
-                                                        spectral_layers = self.spectral_layers)
-
-            self.blocks.append(block)
-
-        # trunc_normal_(self.pos_embed, std=.02)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
-            trunc_normal_(m.weight, std=.02)
-            #nn.init.normal_(m.weight, std=0.02)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    @torch.jit.ignore
-    def no_weight_decay(self):
-        return {'pos_embed', 'cls_token'}
-
-    def forward_features(self, x):
-
-        x = self.pos_drop(x)
-
-        for blk in self.blocks:
-            x = blk(x)
-
-        return x
-
-    def forward(self, x):
-
-        #if self.big_skip:
-        #    residual = x
-
-        #x = self.encoder(x)
-
-        #x = x + self.pos_embed
-        # NOTE: the input x is ignored; features are computed from the positional embedding alone
-        x = self.pos_embed
-
-        x = self.forward_features(x)
-
-        return x
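This deletion removes the only in-repo definition of SphericalFourierNeuralOperatorNet. For reference, a smoke test reconstructed from the module's own docstring example; it is hypothetical (it assumes the deleted sfno_encoder.py is restored on the import path and torch_harmonics is installed), and note that this variant's forward ignores its input and builds features from the positional embedding alone.

import torch
from sfno_encoder import SphericalFourierNeuralOperatorNet  # assumes the deleted file is restored

model = SphericalFourierNeuralOperatorNet(
    img_size=(128, 256),   # (lat, lon) grid
    scale_factor=4,
    in_chans=2,
    out_chans=2,
    embed_dim=16,
    num_layers=2,
    encoder_layers=1,
    spectral_layers=2,
    use_mlp=True,
)

# forward() discards its argument; output channels equal embed_dim
out = model(torch.randn(1, 2, 128, 256))
print(out.shape)  # torch.Size([1, 16, 128, 256])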