data
- app.py +6 -0
- sfno_encoder.py +0 -479
app.py
CHANGED
@@ -76,6 +76,9 @@ def update_fn(val):
     elif val=="Species":
         return gr.Dropdown(label="Name", choices=species_list, interactive=True)
 
+def text_fn(taxon, name):
+    return taxon + ": " + name
+
 def pred_fn(taxon, name):
     if taxon=="Class":
         text_embeds = clas[()][name]
@@ -102,14 +105,17 @@ with gr.Blocks() as demo:
     with gr.Row():
         inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Family", "Genus", "Species"])
         out = gr.Dropdown(label="Name", interactive=True)
+        text = gr.Textbox(label="Text", visible=True, interactive=True)
     inp.change(update_fn, inp, out)
 
     with gr.Row():
+        check_button = gr.Button("Check")
         submit_button = gr.Button("Run Model")
 
     with gr.Row():
         pred = gr.Image(label="Predicted Heatmap", visible=False)
 
+    check_button.click(text_fn, inputs=[inp, out], outputs=[text])
     submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])
 
 demo.launch()
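For context, a minimal, self-contained sketch of the Gradio wiring app.py ends up with after this commit. The choice lists, the `clas` embedding lookup, and the model call live elsewhere in app.py, so `species_list` and the body of `pred_fn` below are placeholders, not the real data or model:

import gradio as gr

species_list = ["Bufo bufo", "Rana temporaria"]  # placeholder choices

def update_fn(val):
    # Repopulate the "Name" dropdown when the hierarchy level changes
    # (the real app branches over all five taxonomic levels).
    if val == "Species":
        return gr.Dropdown(label="Name", choices=species_list, interactive=True)
    return gr.Dropdown(label="Name", interactive=True)

def text_fn(taxon, name):
    # New in this commit: echo the current selection for a quick sanity check.
    return taxon + ": " + name

def pred_fn(taxon, name):
    # Stand-in for the real model call that renders the heatmap.
    return None

with gr.Blocks() as demo:
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy",
                          choices=["Class", "Order", "Family", "Genus", "Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        text = gr.Textbox(label="Text", visible=True, interactive=True)
    inp.change(update_fn, inp, out)

    with gr.Row():
        check_button = gr.Button("Check")
        submit_button = gr.Button("Run Model")

    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=False)

    check_button.click(text_fn, inputs=[inp, out], outputs=[text])
    submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])

demo.launch()

Clicking Check fills the Textbox with, e.g., "Species: Bufo bufo" without running the model; Run Model still drives the heatmap output as before.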
sfno_encoder.py
DELETED
@@ -1,479 +0,0 @@
-import torch
-import torch.nn as nn
-
-from torch_harmonics import *
-
-from torch_harmonics.examples.sfno.models.layers import *
-
-from functools import partial
-
-from einops import repeat
-
-import numpy as np
-
-class SpectralFilterLayer(nn.Module):
-    """
-    Fourier layer. Contains the convolution part of the FNO/SFNO
-    """
-
-    def __init__(
-        self,
-        forward_transform,
-        inverse_transform,
-        embed_dim,
-        filter_type = 'non-linear',
-        operator_type = 'diagonal',
-        sparsity_threshold = 0.0,
-        use_complex_kernels = True,
-        hidden_size_factor = 2,
-        factorization = None,
-        separable = False,
-        rank = 1e-2,
-        complex_activation = 'real',
-        spectral_layers = 1,
-        drop_rate = 0):
-        super(SpectralFilterLayer, self).__init__()
-
-        if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
-            self.filter = SpectralAttentionS2(forward_transform,
-                                              inverse_transform,
-                                              embed_dim,
-                                              operator_type = operator_type,
-                                              sparsity_threshold = sparsity_threshold,
-                                              hidden_size_factor = hidden_size_factor,
-                                              complex_activation = complex_activation,
-                                              spectral_layers = spectral_layers,
-                                              drop_rate = drop_rate,
-                                              bias = False)
-
-        elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
-            self.filter = SpectralAttention2d(forward_transform,
-                                              inverse_transform,
-                                              embed_dim,
-                                              sparsity_threshold = sparsity_threshold,
-                                              use_complex_kernels = use_complex_kernels,
-                                              hidden_size_factor = hidden_size_factor,
-                                              complex_activation = complex_activation,
-                                              spectral_layers = spectral_layers,
-                                              drop_rate = drop_rate,
-                                              bias = False)
-
-        elif filter_type == 'linear':
-            self.filter = SpectralConvS2(forward_transform,
-                                         inverse_transform,
-                                         embed_dim,
-                                         embed_dim,
-                                         operator_type = operator_type,
-                                         rank = rank,
-                                         factorization = factorization,
-                                         separable = separable,
-                                         bias = True)
-
-        else:
-            raise(NotImplementedError)
-
-    def forward(self, x):
-        return self.filter(x)
-
-class SphericalFourierNeuralOperatorBlock(nn.Module):
-    """
-    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
-    """
-    def __init__(
-        self,
-        forward_transform,
-        inverse_transform,
-        embed_dim,
-        filter_type = 'non-linear',
-        operator_type = 'diagonal',
-        mlp_ratio = 2.,
-        drop_rate = 0.,
-        drop_path = 0.,
-        act_layer = nn.GELU,
-        norm_layer = (nn.LayerNorm, nn.LayerNorm),
-        sparsity_threshold = 0.0,
-        use_complex_kernels = True,
-        factorization = None,
-        separable = False,
-        rank = 128,
-        inner_skip = 'linear',
-        outer_skip = None, # None, nn.linear or nn.Identity
-        concat_skip = False,
-        use_mlp = True,
-        complex_activation = 'real',
-        spectral_layers = 3):
-        super(SphericalFourierNeuralOperatorBlock, self).__init__()
-
-        # norm layer
-        self.norm0 = norm_layer[0]() #((h,w))
-
-        # convolution layer
-        self.filter = SpectralFilterLayer(forward_transform,
-                                          inverse_transform,
-                                          embed_dim,
-                                          filter_type,
-                                          operator_type = operator_type,
-                                          sparsity_threshold = sparsity_threshold,
-                                          use_complex_kernels = use_complex_kernels,
-                                          hidden_size_factor = mlp_ratio,
-                                          factorization = factorization,
-                                          separable = separable,
-                                          rank = rank,
-                                          complex_activation = complex_activation,
-                                          spectral_layers = spectral_layers,
-                                          drop_rate = drop_rate)
-
-        if inner_skip == 'linear':
-            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
-        elif inner_skip == 'identity':
-            self.inner_skip = nn.Identity()
-
-        self.concat_skip = concat_skip
-
-        if concat_skip and inner_skip is not None:
-            self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
-
-        if filter_type == 'linear' or filter_type == 'local':
-            self.act_layer = act_layer()
-
-        # dropout
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
-        # norm layer
-        self.norm1 = norm_layer[1]() #((h,w))
-
-        if use_mlp == True:
-            mlp_hidden_dim = int(embed_dim * mlp_ratio)
-            self.mlp = MLP(in_features = embed_dim,
-                           hidden_features = mlp_hidden_dim,
-                           act_layer = act_layer,
-                           drop_rate = drop_rate,
-                           checkpointing = False)
-
-        if outer_skip == 'linear':
-            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
-        elif outer_skip == 'identity':
-            self.outer_skip = nn.Identity()
-
-        if concat_skip and outer_skip is not None:
-            self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
-
-    def forward(self, x):
-
-        x = self.norm0(x)
-
-        x, residual = self.filter(x)
-
-        if hasattr(self, 'inner_skip'):
-            if self.concat_skip:
-                x = torch.cat((x, self.inner_skip(residual)), dim=1)
-                x = self.inner_skip_conv(x)
-            else:
-                x = x + self.inner_skip(residual)
-
-        if hasattr(self, 'act_layer'):
-            x = self.act_layer(x)
-
-        x = self.norm1(x)
-
-        if hasattr(self, 'mlp'):
-            x = self.mlp(x)
-
-        x = self.drop_path(x)
-
-        if hasattr(self, 'outer_skip'):
-            if self.concat_skip:
-                x = torch.cat((x, self.outer_skip(residual)), dim=1)
-                x = self.outer_skip_conv(x)
-            else:
-                x = x + self.outer_skip(residual)
-
-        return x
-
-class SphericalFourierNeuralOperatorNet(nn.Module):
-    """
-    SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
-    both linear and non-linear variants.
-
-    Parameters
-    ----------
-    filter_type : str, optional
-        Type of filter to use ('linear', 'non-linear'), by default "linear"
-    spectral_transform : str, optional
-        Type of spectral transformation to use, by default "sht"
-    operator_type : str, optional
-        Type of operator to use ('vector', 'diagonal'), by default "vector"
-    img_shape : tuple, optional
-        Shape of the input channels, by default (128, 256)
-    scale_factor : int, optional
-        Scale factor to use, by default 3
-    in_chans : int, optional
-        Number of input channels, by default 3
-    out_chans : int, optional
-        Number of output channels, by default 3
-    embed_dim : int, optional
-        Dimension of the embeddings, by default 256
-    num_layers : int, optional
-        Number of layers in the network, by default 4
-    activation_function : str, optional
-        Activation function to use, by default "gelu"
-    encoder_layers : int, optional
-        Number of layers in the encoder, by default 1
-    use_mlp : int, optional
-        Whether to use MLP, by default True
-    mlp_ratio : int, optional
-        Ratio of MLP to use, by default 2.0
-    drop_rate : float, optional
-        Dropout rate, by default 0.0
-    drop_path_rate : float, optional
-        Dropout path rate, by default 0.0
-    sparsity_threshold : float, optional
-        Threshold for sparsity, by default 0.0
-    normalization_layer : str, optional
-        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
-    hard_thresholding_fraction : float, optional
-        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
-    use_complex_kernels : bool, optional
-        Whether to use complex kernels, by default True
-    big_skip : bool, optional
-        Whether to add a single large skip connection, by default True
-    rank : float, optional
-        Rank of the approximation, by default 1.0
-    factorization : Any, optional
-        Type of factorization to use, by default None
-    separable : bool, optional
-        Whether to use separable convolutions, by default False
-    rank : (int, Tuple[int]), optional
-        If a factorization is used, which rank to use. Argument is passed to tensorly
-    complex_activation : str, optional
-        Type of complex activation function to use, by default "real"
-    spectral_layers : int, optional
-        Number of spectral layers, by default 3
-    pos_embed : bool, optional
-        Whether to use positional embedding, by default True
-
-    Example:
-    --------
-    >>> model = SphericalFourierNeuralOperatorNet(
-    ...         img_shape=(128, 256),
-    ...         scale_factor=4,
-    ...         in_chans=2,
-    ...         out_chans=2,
-    ...         embed_dim=16,
-    ...         num_layers=2,
-    ...         encoder_layers=1,
-    ...         num_blocks=4,
-    ...         spectral_layers=2,
-    ...         use_mlp=True,)
-    >>> model(torch.randn(1, 2, 128, 256)).shape
-    torch.Size([1, 2, 128, 256])
-    """
-
-    def __init__(
-        self,
-        filter_type = 'linear',
-        spectral_transform = 'sht',
-        operator_type = 'vector',
-        img_size = (128, 256),
-        scale_factor = 4,
-        in_chans = 3,
-        out_chans = 3,
-        embed_dim = 256,
-        num_layers = 4,
-        activation_function = 'gelu',
-        encoder_layers = 1,
-        use_mlp = True,
-        mlp_ratio = 2.,
-        drop_rate = 0.,
-        drop_path_rate = 0.,
-        sparsity_threshold = 0.0,
-        normalization_layer = 'instance_norm',
-        hard_thresholding_fraction = 1.0,
-        use_complex_kernels = True,
-        big_skip = False,
-        factorization = None,
-        separable = False,
-        rank = 128,
-        complex_activation = 'real',
-        spectral_layers = 2,
-        pos_embed = True
-        ):
-
-        super(SphericalFourierNeuralOperatorNet, self).__init__()
-
-        self.filter_type = filter_type
-        self.spectral_transform = spectral_transform
-        self.operator_type = operator_type
-        self.img_size = img_size
-        self.scale_factor = scale_factor
-        self.in_chans = in_chans
-        self.out_chans = out_chans
-        self.embed_dim = self.num_features = embed_dim
-        self.pos_embed_dim = self.embed_dim
-        self.num_layers = num_layers
-        self.hard_thresholding_fraction = hard_thresholding_fraction
-        self.normalization_layer = normalization_layer
-        self.use_mlp = use_mlp
-        self.encoder_layers = encoder_layers
-        self.big_skip = big_skip
-        self.factorization = factorization
-        self.separable = separable,
-        self.rank = rank
-        self.complex_activation = complex_activation
-        self.spectral_layers = spectral_layers
-
-        # activation function
-        if activation_function == 'relu':
-            self.activation_function = nn.ReLU
-        elif activation_function == 'gelu':
-            self.activation_function = nn.GELU
-        else:
-            raise ValueError(f"Unknown activation function {activation_function}")
-
-        # compute downsampled image size
-        self.h = self.img_size[0] // scale_factor
-        self.w = self.img_size[1] // scale_factor
-
-        # dropout
-        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
-
-        # pick norm layer
-        if self.normalization_layer == "layer_norm":
-            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
-            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
-        elif self.normalization_layer == "instance_norm":
-            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
-            norm_layer1 = norm_layer0
-        elif self.normalization_layer == "none":
-            norm_layer0 = nn.Identity
-            norm_layer1 = norm_layer0
-        else:
-            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
-
-        if pos_embed:
-            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
-            #self.pos_embed = posemb_sincos_2d(900, 1800, 128)
-            pass
-            #x = torch.linspace(-np.pi, np.pi, 900)
-            #y = torch.linspace(-np.pi, np.pi, 1800)
-            #x, y = torch.meshgrid(x, y)
-            #self.pos_embed = torch.stack((torch.sin(x), torch.sin(y), torch.cos(x), torch.cos(y)), dim=0).unsqueeze(0).cuda()
-            #self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
-            #self.pos_direct = nn.Conv2d(4, self.embed_dim, 1, bias=False)
-        else:
-            self.pos_embed = None
-
-        # encoder
-        """encoder_hidden_dim = self.embed_dim
-        current_dim = self.in_chans
-        encoder_modules = []
-        for i in range(self.encoder_layers):
-            encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
-            encoder_modules.append(self.activation_function())
-            current_dim = encoder_hidden_dim
-        encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
-        self.encoder = nn.Sequential(*encoder_modules)"""
-
-        # prepare the spectral transform
-        if self.spectral_transform == 'sht':
-
-            modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
-
-            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
-            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
-            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
-            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
-
-        elif self.spectral_transform == 'fft':
-
-            modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
-
-            self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
-            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
-            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
-            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
-
-        else:
-            raise(ValueError('Unknown spectral transform'))
-
-        self.blocks = nn.ModuleList([])
-        for i in range(self.num_layers):
-
-            first_layer = i == 0
-            last_layer = i == self.num_layers-1
-
-            forward_transform = self.trans_down if first_layer else self.trans
-            inverse_transform = self.itrans_up if last_layer else self.itrans
-
-            inner_skip = 'linear'
-            outer_skip = 'identity'
-
-            if first_layer:
-                norm_layer = (norm_layer0, norm_layer1)
-            elif last_layer:
-                norm_layer = (norm_layer1, norm_layer0)
-            else:
-                norm_layer = (norm_layer1, norm_layer1)
-
-            block = SphericalFourierNeuralOperatorBlock(forward_transform,
-                                                        inverse_transform,
-                                                        self.embed_dim,
-                                                        filter_type = filter_type,
-                                                        operator_type = self.operator_type,
-                                                        mlp_ratio = mlp_ratio,
-                                                        drop_rate = drop_rate,
-                                                        drop_path = dpr[i],
-                                                        act_layer = self.activation_function,
-                                                        norm_layer = norm_layer,
-                                                        sparsity_threshold = sparsity_threshold,
-                                                        use_complex_kernels = use_complex_kernels,
-                                                        inner_skip = inner_skip,
-                                                        outer_skip = outer_skip,
-                                                        use_mlp = use_mlp,
-                                                        factorization = self.factorization,
-                                                        separable = self.separable,
-                                                        rank = self.rank,
-                                                        complex_activation = self.complex_activation,
-                                                        spectral_layers = self.spectral_layers)
-
-            self.blocks.append(block)
-
-        # trunc_normal_(self.pos_embed, std=.02)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
-            trunc_normal_(m.weight, std=.02)
-            #nn.init.normal_(m.weight, std=0.02)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    @torch.jit.ignore
-    def no_weight_decay(self):
-        return {'pos_embed', 'cls_token'}
-
-    def forward_features(self, x):
-
-        x = self.pos_drop(x)
-
-        for blk in self.blocks:
-            x = blk(x)
-
-        return x
-
-    def forward(self, x):
-
-        #if self.big_skip:
-            #residual = x
-
-        #x = self.encoder(x)
-
-        #x = x + self.pos_embed
-        x = self.pos_embed
-
-        x = self.forward_features(x)
-
-        return x