data

- .gitattributes +2 -0
- app.py +21 -0
- data/pos_embeds_model.npy +3 -0
- model/demo_model.pt +3 -0
- requirements.txt +4 -0
- sfno_encoder.py +479 -0
.gitattributes
CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 data/class_70b.npy filter=lfs diff=lfs merge=lfs -text
 data/order_70b.npy filter=lfs diff=lfs merge=lfs -text
 data/species_70b.npy filter=lfs diff=lfs merge=lfs -text
+data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
+model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
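(These two new rules are exactly what `git lfs track` appends for binary artifacts: with them, Git stores data/pos_embeds_model.npy and model/demo_model.pt as LFS pointer stubs, which is why those files appear below as three-line pointers rather than binary content.)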
app.py
CHANGED
@@ -13,6 +13,8 @@ order_list = list(order[()].keys())
 #genus_list = list(genus[()].keys())
 #family_list = list(family[()].keys())
 
+pos_embed = np.load("data/pos_embed.npy", allow_pickle=True)
+
 def update_fn(val):
     if val=="Class":
         return gr.Dropdown(label="Name", choices=class_list, interactive=True)
@@ -25,6 +27,20 @@ def update_fn(val):
     elif val=="Species":
         return gr.Dropdown(label="Name", choices=species_list, interactive=True)
 
+def pred_fn(taxon, name):
+    if taxon=="Class":
+        text_embeds = clas[()][name]
+    elif taxon=="Order":
+        text_embeds = order[()][name]
+    elif taxon=="Family":
+        text_embeds = family[()][name]
+    elif taxon=="Genus":
+        text_embeds = genus[()][name]
+    elif taxon=="Species":
+        text_embeds = species[()][name]
+
+
+
 with gr.Blocks() as demo:
     gr.Markdown(
     """
@@ -39,5 +55,10 @@ with gr.Blocks() as demo:
     with gr.Row():
         submit_button = gr.Button("Run Model")
 
+    with gr.Row():
+        pred = gr.Image(label="Predicted Heatmap", visible=False)
+
+    submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])
 
+
 demo.launch()
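As committed, pred_fn only selects text_embeds and returns nothing, yet submit_button.click wires it to the pred image output, so clicking the button would update the image with None. A minimal sketch of one way the function could be finished (hypothetical: `model` and the embeds-to-heatmap scoring are assumptions, not part of this commit):

import torch
import gradio as gr

# Hypothetical completion of pred_fn. clas/order/family/genus/species are the
# arrays app.py already loads; `model` (the SFNO demo encoder) is assumed.
def pred_fn(taxon, name):
    lookup = {"Class": clas, "Order": order, "Family": family,
              "Genus": genus, "Species": species}   # dict lookup replaces the if/elif chain
    text_embeds = lookup[taxon][()][name]
    with torch.no_grad():
        img_embeds = model(None)   # the demo model ignores its input (see sfno_encoder.py)
        scores = torch.einsum("chw,c->hw", img_embeds.squeeze(0),
                              torch.as_tensor(text_embeds, dtype=img_embeds.dtype))
    heat = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    return gr.Image(value=heat.cpu().numpy(), visible=True)   # Gradio 4.x-style component update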
data/pos_embeds_model.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e010369a67f1d2946dd787494a65f88c1ed79d1cc6d4a5be3f5ac98568492630
size 829440128
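(829440128 bytes is exactly a float32 array of shape (1, 128, 900, 1800) plus the 128-byte .npy header, which lines up with the commented-out posemb_sincos_2d(900, 1800, 128) experiments in sfno_encoder.py; an inference from the file size, not something the commit states.)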
model/demo_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5d592a3a09086658aeac9b51574f11977962c0f2d5703e0225c3a236be4592d
size 76024944
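(76024944 bytes is about 72 MB, room for roughly 19 million float32 parameters if this is a plain state_dict; notably far too small to contain the 829 MB positional embedding, which ships separately above. Again an inference from file sizes only.)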
requirements.txt
ADDED
@@ -0,0 +1,4 @@
numpy==1.23.4
torch==2.0.1
rasterio==1.3.8
einops==0.6.1
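Note: app.py imports gradio (as gr) and sfno_encoder.py imports torch_harmonics, neither of which is pinned here. Hugging Face Spaces preinstalls gradio, but torch_harmonics would likely need its own entry, e.g. (PyPI name; version pin left open):

torch-harmonics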
sfno_encoder.py
ADDED
@@ -0,0 +1,479 @@
import torch
import torch.nn as nn

from torch_harmonics import *

from torch_harmonics.examples.sfno.models.layers import *

from functools import partial

from einops import repeat

import numpy as np

class SpectralFilterLayer(nn.Module):
    """
    Fourier layer. Contains the convolution part of the FNO/SFNO
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type = 'non-linear',
        operator_type = 'diagonal',
        sparsity_threshold = 0.0,
        use_complex_kernels = True,
        hidden_size_factor = 2,
        factorization = None,
        separable = False,
        rank = 1e-2,
        complex_activation = 'real',
        spectral_layers = 1,
        drop_rate = 0):
        super(SpectralFilterLayer, self).__init__()

        if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
            self.filter = SpectralAttentionS2(forward_transform,
                                              inverse_transform,
                                              embed_dim,
                                              operator_type = operator_type,
                                              sparsity_threshold = sparsity_threshold,
                                              hidden_size_factor = hidden_size_factor,
                                              complex_activation = complex_activation,
                                              spectral_layers = spectral_layers,
                                              drop_rate = drop_rate,
                                              bias = False)

        elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
            self.filter = SpectralAttention2d(forward_transform,
                                              inverse_transform,
                                              embed_dim,
                                              sparsity_threshold = sparsity_threshold,
                                              use_complex_kernels = use_complex_kernels,
                                              hidden_size_factor = hidden_size_factor,
                                              complex_activation = complex_activation,
                                              spectral_layers = spectral_layers,
                                              drop_rate = drop_rate,
                                              bias = False)

        elif filter_type == 'linear':
            self.filter = SpectralConvS2(forward_transform,
                                         inverse_transform,
                                         embed_dim,
                                         embed_dim,
                                         operator_type = operator_type,
                                         rank = rank,
                                         factorization = factorization,
                                         separable = separable,
                                         bias = True)

        else:
            raise NotImplementedError

    def forward(self, x):
        return self.filter(x)

class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
    """
    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type = 'non-linear',
        operator_type = 'diagonal',
        mlp_ratio = 2.,
        drop_rate = 0.,
        drop_path = 0.,
        act_layer = nn.GELU,
        norm_layer = (nn.LayerNorm, nn.LayerNorm),
        sparsity_threshold = 0.0,
        use_complex_kernels = True,
        factorization = None,
        separable = False,
        rank = 128,
        inner_skip = 'linear',
        outer_skip = None, # None, nn.linear or nn.Identity
        concat_skip = False,
        use_mlp = True,
        complex_activation = 'real',
        spectral_layers = 3):
        super(SphericalFourierNeuralOperatorBlock, self).__init__()

        # norm layer
        self.norm0 = norm_layer[0]() #((h,w))

        # convolution layer
        self.filter = SpectralFilterLayer(forward_transform,
                                          inverse_transform,
                                          embed_dim,
                                          filter_type,
                                          operator_type = operator_type,
                                          sparsity_threshold = sparsity_threshold,
                                          use_complex_kernels = use_complex_kernels,
                                          hidden_size_factor = mlp_ratio,
                                          factorization = factorization,
                                          separable = separable,
                                          rank = rank,
                                          complex_activation = complex_activation,
                                          spectral_layers = spectral_layers,
                                          drop_rate = drop_rate)

        if inner_skip == 'linear':
            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif inner_skip == 'identity':
            self.inner_skip = nn.Identity()

        self.concat_skip = concat_skip

        if concat_skip and inner_skip is not None:
            self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)

        if filter_type == 'linear' or filter_type == 'local':
            self.act_layer = act_layer()

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # norm layer
        self.norm1 = norm_layer[1]() #((h,w))

        if use_mlp:
            mlp_hidden_dim = int(embed_dim * mlp_ratio)
            self.mlp = MLP(in_features = embed_dim,
                           hidden_features = mlp_hidden_dim,
                           act_layer = act_layer,
                           drop_rate = drop_rate,
                           checkpointing = False)

        if outer_skip == 'linear':
            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif outer_skip == 'identity':
            self.outer_skip = nn.Identity()

        if concat_skip and outer_skip is not None:
            self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)

    def forward(self, x):

        x = self.norm0(x)

        x, residual = self.filter(x)

        if hasattr(self, 'inner_skip'):
            if self.concat_skip:
                x = torch.cat((x, self.inner_skip(residual)), dim=1)
                x = self.inner_skip_conv(x)
            else:
                x = x + self.inner_skip(residual)

        if hasattr(self, 'act_layer'):
            x = self.act_layer(x)

        x = self.norm1(x)

        if hasattr(self, 'mlp'):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, 'outer_skip'):
            if self.concat_skip:
                x = torch.cat((x, self.outer_skip(residual)), dim=1)
                x = self.outer_skip_conv(x)
            else:
                x = x + self.outer_skip(residual)

        return x

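# Dataflow of one block, schematically (summary comment, not in the original file):
#
#     x = norm0(x)
#     x, residual = filter(x)          # forward transform -> spectral filter -> inverse transform
#     x = x + inner_skip(residual)     # 1x1-conv skip when inner_skip == 'linear'
#     x = act(x)                       # only for 'linear'/'local' filter types
#     x = mlp(norm1(x))                # when use_mlp is set
#     x = drop_path(x) + outer_skip(residual)
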
class SphericalFourierNeuralOperatorNet(nn.Module):
    """
    SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
    both linear and non-linear variants.

    Parameters
    ----------
    filter_type : str, optional
        Type of filter to use ('linear', 'non-linear'), by default "linear"
    spectral_transform : str, optional
        Type of spectral transformation to use, by default "sht"
    operator_type : str, optional
        Type of operator to use ('vector', 'diagonal'), by default "vector"
    img_size : tuple, optional
        Shape of the input image, by default (128, 256)
    scale_factor : int, optional
        Scale factor to use, by default 4
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 4
    activation_function : str, optional
        Activation function to use, by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    use_mlp : bool, optional
        Whether to use MLP, by default True
    mlp_ratio : float, optional
        Ratio of MLP to use, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    sparsity_threshold : float, optional
        Threshold for sparsity, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    use_complex_kernels : bool, optional
        Whether to use complex kernels, by default True
    big_skip : bool, optional
        Whether to add a single large skip connection, by default False
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    rank : (int, Tuple[int]), optional
        If a factorization is used, which rank to use. Argument is passed to tensorly, by default 128
    complex_activation : str, optional
        Type of complex activation function to use, by default "real"
    spectral_layers : int, optional
        Number of spectral layers, by default 2
    pos_embed : bool, optional
        Whether to use positional embedding, by default True

    Example:
    --------
    >>> model = SphericalFourierNeuralOperatorNet(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=2,
    ...         encoder_layers=1,
    ...         spectral_layers=2,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 16, 128, 256])
    """

    def __init__(
        self,
        filter_type = 'linear',
        spectral_transform = 'sht',
        operator_type = 'vector',
        img_size = (128, 256),
        scale_factor = 4,
        in_chans = 3,
        out_chans = 3,
        embed_dim = 256,
        num_layers = 4,
        activation_function = 'gelu',
        encoder_layers = 1,
        use_mlp = True,
        mlp_ratio = 2.,
        drop_rate = 0.,
        drop_path_rate = 0.,
        sparsity_threshold = 0.0,
        normalization_layer = 'instance_norm',
        hard_thresholding_fraction = 1.0,
        use_complex_kernels = True,
        big_skip = False,
        factorization = None,
        separable = False,
        rank = 128,
        complex_activation = 'real',
        spectral_layers = 2,
        pos_embed = True
        ):

        super(SphericalFourierNeuralOperatorNet, self).__init__()

        self.filter_type = filter_type
        self.spectral_transform = spectral_transform
        self.operator_type = operator_type
        self.img_size = img_size
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = self.num_features = embed_dim
        self.pos_embed_dim = self.embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.big_skip = big_skip
        self.factorization = factorization
        self.separable = separable
        self.rank = rank
        self.complex_activation = complex_activation
        self.spectral_layers = spectral_layers

        # activation function
        if activation_function == 'relu':
            self.activation_function = nn.ReLU
        elif activation_function == 'gelu':
            self.activation_function = nn.GELU
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size
        self.h = self.img_size[0] // scale_factor
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # pick norm layer
        if self.normalization_layer == "layer_norm":
            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
        elif self.normalization_layer == "instance_norm":
            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
            norm_layer1 = norm_layer0
        elif self.normalization_layer == "none":
            norm_layer0 = nn.Identity
            norm_layer1 = norm_layer0
        else:
            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

        if pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
            #self.pos_embed = posemb_sincos_2d(900, 1800, 128)
            #x = torch.linspace(-np.pi, np.pi, 900)
            #y = torch.linspace(-np.pi, np.pi, 1800)
            #x, y = torch.meshgrid(x, y)
            #self.pos_embed = torch.stack((torch.sin(x), torch.sin(y), torch.cos(x), torch.cos(y)), dim=0).unsqueeze(0).cuda()
            #self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
            #self.pos_direct = nn.Conv2d(4, self.embed_dim, 1, bias=False)
        else:
            self.pos_embed = None

        # encoder
        """encoder_hidden_dim = self.embed_dim
        current_dim = self.in_chans
        encoder_modules = []
        for i in range(self.encoder_layers):
            encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
            encoder_modules.append(self.activation_function())
            current_dim = encoder_hidden_dim
        encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
        self.encoder = nn.Sequential(*encoder_modules)"""

        # prepare the spectral transform
        if self.spectral_transform == 'sht':

            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()

        elif self.spectral_transform == 'fft':

            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

            self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()

        else:
            raise ValueError('Unknown spectral transform')

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            first_layer = i == 0
            last_layer = i == self.num_layers-1

            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans

            inner_skip = 'linear'
            outer_skip = 'identity'

            if first_layer:
                norm_layer = (norm_layer0, norm_layer1)
            elif last_layer:
                norm_layer = (norm_layer1, norm_layer0)
            else:
                norm_layer = (norm_layer1, norm_layer1)

            block = SphericalFourierNeuralOperatorBlock(forward_transform,
                                                        inverse_transform,
                                                        self.embed_dim,
                                                        filter_type = filter_type,
                                                        operator_type = self.operator_type,
                                                        mlp_ratio = mlp_ratio,
                                                        drop_rate = drop_rate,
                                                        drop_path = dpr[i],
                                                        act_layer = self.activation_function,
                                                        norm_layer = norm_layer,
                                                        sparsity_threshold = sparsity_threshold,
                                                        use_complex_kernels = use_complex_kernels,
                                                        inner_skip = inner_skip,
                                                        outer_skip = outer_skip,
                                                        use_mlp = use_mlp,
                                                        factorization = self.factorization,
                                                        separable = self.separable,
                                                        rank = self.rank,
                                                        complex_activation = self.complex_activation,
                                                        spectral_layers = self.spectral_layers)

            self.blocks.append(block)

        # trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            #nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):

        #if self.big_skip:
        #residual = x

        #x = self.encoder(x)

        #x = x + self.pos_embed
        # NOTE: the demo generates features from the positional embedding alone;
        # the input x is ignored.
        x = self.pos_embed

        x = self.forward_features(x)

        return x
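A minimal sketch of how the pieces of this commit could fit together. Everything below the imports is an assumption, not established by the diff: the checkpoint layout of model/demo_model.pt, and the 900x1800 grid / 128-dim embedding suggested by the commented-out code and the pos_embeds_model.npy file size.

import numpy as np
import torch
from sfno_encoder import SphericalFourierNeuralOperatorNet

# Grid and embedding sizes are assumptions (see the commented-out
# posemb_sincos_2d(900, 1800, 128) above); so is the state_dict layout.
model = SphericalFourierNeuralOperatorNet(img_size=(900, 1800), embed_dim=128)
state = torch.load("model/demo_model.pt", map_location="cpu")
model.load_state_dict(state, strict=False)   # strict=False: pos_embed may live outside the checkpoint

# Swap in the separately shipped positional embedding (assumed shape (1, 128, 900, 1800)):
pos = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.pos_embed = torch.nn.Parameter(torch.as_tensor(pos, dtype=torch.float32))

model.eval()
with torch.no_grad():
    feats = model(None)    # forward() ignores its argument and reads self.pos_embed
print(feats.shape)         # expected: torch.Size([1, 128, 900, 1800])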