rr-ss committed on
Commit 0677249 · verified · 1 Parent(s): 05ebc17

Upload folder using huggingface_hub
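For context: a commit with this message is what huggingface_hub produces when a local folder is pushed in one shot. A minimal sketch of how such an upload is typically done; the repo id and folder path here are placeholders, not taken from this page:

from huggingface_hub import upload_folder

upload_folder(
    repo_id="rr-ss/<repo-name>",  # placeholder; the actual repo id is not shown here
    folder_path="./polarisnet",   # local folder containing polarisnet.py, sft_loop.pt, ...
    commit_message="Upload folder using huggingface_hub",
)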

__pycache__/polarisnet.cpython-39.pyc ADDED
Binary file (14.4 kB).
polarisnet.py ADDED
@@ -0,0 +1,526 @@
import torch
import torch.nn as nn
from operator import itemgetter

from typing import Type, Callable, Tuple, Optional, Set, List, Union
from timm.models.layers import drop_path, trunc_normal_, Mlp, DropPath
from timm.models.efficientnet_blocks import SqueezeExcite, DepthwiseSeparableConv

def exists(val):
    return val is not None

def map_el_ind(arr, ind):
    return list(map(itemgetter(ind), arr))

def sort_and_return_indices(arr):
    # Returns the sorted values and the argsort indices
    # (for a permutation, the argsort is its inverse).
    indices = [ind for ind in range(len(arr))]
    arr = zip(arr, indices)
    arr = sorted(arr)

    return map_el_ind(arr, 0), map_el_ind(arr, 1)

def calculate_permutations(num_dimensions, emb_dim):
    # One permutation per spatial axis, moving that axis (plus the embedding
    # dim) into the last two positions so attention can run along it.
    total_dimensions = num_dimensions + 2
    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]

    permutations = []

    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        # A list comprehension keeps a deterministic order (sets do not guarantee one).
        dims_rest = [ind for ind in range(total_dimensions) if ind not in last_two_dims]
        permutation = [*dims_rest, *last_two_dims]
        permutations.append(permutation)

    return permutations

class ChanLayerNorm(nn.Module):
    # LayerNorm over the channel dimension of an NCHW tensor.
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class PermuteToFrom(nn.Module):
    # Moves the target axis next to the embedding dim, folds all other dims
    # into the batch, applies fn, then restores the original layout.
    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        _, inv_permutation = sort_and_return_indices(permutation)
        self.permutation = permutation
        self.inv_permutation = inv_permutation

    def forward(self, x, **kwargs):
        axial = x.permute(*self.permutation).contiguous()
        shape = axial.shape
        *_, t, d = shape

        # Merge all leading dims into the batch, attend along t, then undo.
        axial = axial.reshape(-1, t, d)
        axial = self.fn(axial, **kwargs)
        axial = axial.reshape(*shape)
        axial = axial.permute(*self.inv_permutation).contiguous()

        return axial

class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        # One broadcastable parameter per spatial axis.
        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            param_shape = [1] * total_dimensions  # renamed from `shape` to avoid shadowing the argument
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            parameter = nn.Parameter(torch.randn(*param_shape))
            setattr(self, f'param_{i}', parameter)

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')

        return x

class SelfAttention(nn.Module):
    def __init__(self, dim, heads, dim_heads=None, drop=0):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads
        self.drop_rate = drop
        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
        self.to_out = nn.Linear(dim_hidden, dim)
        # Note: DropPath (stochastic depth) is applied to the output projection,
        # dropping whole samples rather than individual activations.
        self.proj_drop = DropPath(drop)

    def forward(self, x, kv=None):
        kv = x if kv is None else kv
        q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))
        b, t, d, h, e = *q.shape, self.heads, self.dim_heads

        # (b, t, h*e) -> (b*h, t, e)
        merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
        q, k, v = map(merge_heads, (q, k, v))

        dots = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        dots = dots.softmax(dim=-1)

        out = torch.einsum('bij,bje->bie', dots, v)
        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
        out = self.to_out(out)
        out = self.proj_drop(out)

        return out

class AxialTransformerBlock(nn.Module):
    def __init__(self,
                 dim,
                 axial_pos_emb_shape,
                 pos_embed,
                 heads=8,
                 dim_heads=None,
                 drop=0.,
                 drop_path_rate=0.,
                 ):
        super().__init__()

        dim_index = 1
        permutations = calculate_permutations(2, dim_index)

        self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index) if pos_embed else nn.Identity()

        # One attention module per spatial axis: height first, then width.
        self.height_attn, self.width_attn = nn.ModuleList([
            PermuteToFrom(permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads, drop=drop)))
            for permutation in permutations
        ])

        # Two stacked convolutional FFN sub-blocks in a single residual branch.
        self.FFN = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding=1),
            nn.GELU(),
            DropPath(drop),
            nn.Conv2d(dim * 4, dim, 3, padding=1),
            DropPath(drop),

            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding=1),
            nn.GELU(),
            DropPath(drop),
            nn.Conv2d(dim * 4, dim, 3, padding=1),
            DropPath(drop),
        )

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.pos_emb(x)
        x = x + self.drop_path(self.height_attn(x))
        x = x + self.drop_path(self.width_attn(x))
        x = x + self.drop_path(self.FFN(x))

        return x

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def _gelu_ignore_parameters(*args, **kwargs) -> nn.Module:
    # timm blocks pass kwargs (e.g. inplace=True) to the activation
    # constructor; nn.GELU does not accept them, so swallow them here.
    activation = nn.GELU()
    return activation

class DoubleConv(nn.Module):

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            downscale: bool = False,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            drop_path: float = 0.,
    ) -> None:
        super(DoubleConv, self).__init__()

        self.drop_path_rate: float = drop_path

        if act_layer == nn.GELU:
            act_layer = _gelu_ignore_parameters

        self.main_path = nn.Sequential(
            norm_layer(in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1)),
            DepthwiseSeparableConv(in_chs=in_channels, out_chs=out_channels, stride=2 if downscale else 1,
                                   act_layer=act_layer, norm_layer=norm_layer, drop_path_rate=drop_path),
            SqueezeExcite(in_chs=out_channels, rd_ratio=0.25),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1))
        )

        # Match the main path's spatial/channel shape on the skip connection.
        if downscale:
            self.skip_path = nn.Sequential(
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
            )
        else:
            self.skip_path = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.main_path(x)

        if self.drop_path_rate > 0.:
            output = drop_path(output, self.drop_path_rate, self.training)

        x = output + self.skip_path(x)

        return x


class DeconvModule(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=nn.BatchNorm2d,
                 act_layer=nn.Mish,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        assert (kernel_size - scale_factor >= 0) and \
               (kernel_size - scale_factor) % 2 == 0, \
            f'kernel_size should be greater than or equal to scale_factor ' \
            f'and (kernel_size - scale_factor) should be an even number, ' \
            f'while the kernel size is {kernel_size} and scale_factor is ' \
            f'{scale_factor}.'

        # With stride == scale_factor and symmetric padding, the transposed
        # convolution upsamples by exactly scale_factor.
        stride = scale_factor
        padding = (kernel_size - scale_factor) // 2
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        norm = norm_layer(out_channels)
        activate = act_layer()
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        out = self.deconv_upsamping(x)

        return out

class Stage(nn.Module):

    def __init__(self,
                 image_size: int,
                 depth: int,
                 in_channels: int,
                 out_channels: int,
                 type_name: str,
                 pos_embed: bool,
                 num_heads: int = 32,
                 drop: float = 0.,
                 drop_path: Union[List[float], float] = 0.,
                 act_layer: Type[nn.Module] = nn.GELU,
                 norm_layer: Type[nn.Module] = nn.BatchNorm2d,
                 ):
        super().__init__()
        self.type_name = type_name

        # Accept a scalar drop_path as well as a per-block schedule.
        if isinstance(drop_path, float):
            drop_path = [drop_path] * depth

        if self.type_name == "encoder":
            # Downsample by 2, then run axial transformer blocks.
            self.conv = DoubleConv(
                in_channels=in_channels,
                out_channels=out_channels,
                downscale=True,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path=drop_path[0],
            )

            self.blocks = nn.Sequential(*[
                AxialTransformerBlock(
                    dim=out_channels,
                    axial_pos_emb_shape=pair(image_size),
                    heads=num_heads,
                    drop=drop,
                    drop_path_rate=drop_path[index],
                    dim_heads=None,
                    pos_embed=pos_embed
                )
                for index in range(depth)
            ])

        elif self.type_name == "decoder":
            # Upsample by 2, fuse with the encoder skip, then run blocks.
            self.upsample = DeconvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                norm_layer=norm_layer,
                act_layer=act_layer
            )

            self.conv = DoubleConv(
                in_channels=in_channels,
                out_channels=out_channels,
                downscale=False,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path=drop_path[0],
            )

            self.blocks = nn.Sequential(*[
                AxialTransformerBlock(
                    dim=out_channels,
                    axial_pos_emb_shape=pair(image_size),
                    heads=num_heads,
                    drop=drop,
                    drop_path_rate=drop_path[index],
                    dim_heads=None,
                    pos_embed=pos_embed
                )
                for index in range(depth)
            ])

        else:
            raise ValueError(f'unknown type_name: {type_name}')

    def forward(self, x, skip=None):
        if self.type_name == "encoder":
            x = self.conv(x)
            x = self.blocks(x)

        elif self.type_name == "decoder":
            x = self.upsample(x)
            # Concatenating the skip doubles the channels back to in_channels.
            x = torch.cat([skip, x], dim=1)
            x = self.conv(x)
            x = self.blocks(x)

        return x

class FinalExpand(nn.Module):
    def __init__(
            self,
            in_channels,
            embed_dim,
            out_channels,  # unused; kept for call-site symmetry
            norm_layer,
            act_layer,
    ):
        super().__init__()
        self.upsample = DeconvModule(
            in_channels=in_channels,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer
        )

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=embed_dim * 2, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
        )

    def forward(self, skip, x):
        x = self.upsample(x)
        x = torch.cat([skip, x], dim=1)
        x = self.conv(x)

        return x

class polarisnet(nn.Module):
    def __init__(
            self,
            image_size=224,
            in_channels=1,
            out_channels=1,
            embed_dim=64,
            depths=[2, 2, 2, 2],
            channels=[64, 128, 256, 512],
            num_heads=16,
            drop=0.,
            drop_path=0.1,
            act_layer=nn.GELU,
            norm_layer=nn.BatchNorm2d,
            pos_embed=False
    ):
        super(polarisnet, self).__init__()
        self.num_stages = len(depths)
        self.num_features = channels[-1]
        self.embed_dim = channels[0]  # assumes channels[0] == embed_dim, as in the defaults

        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
        )

        # Linearly increasing stochastic-depth schedule across all blocks.
        drop_path = torch.linspace(0.0, drop_path, sum(depths)).tolist()
        encoder_stages = []

        for index in range(self.num_stages):
            encoder_stages.append(
                Stage(
                    image_size=image_size // (pow(2, 1 + index)),
                    depth=depths[index],
                    in_channels=embed_dim if index == 0 else channels[index - 1],
                    out_channels=channels[index],
                    num_heads=num_heads,
                    drop=drop,
                    drop_path=drop_path[sum(depths[:index]):sum(depths[:index + 1])],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name="encoder",
                    pos_embed=pos_embed
                )
            )

        self.encoder_stages = nn.ModuleList(encoder_stages)

        decoder_stages = []

        for index in range(self.num_stages - 1):
            decoder_stages.append(
                Stage(
                    image_size=image_size // (pow(2, self.num_stages - index - 1)),
                    depth=depths[self.num_stages - index - 2],
                    in_channels=channels[self.num_stages - index - 1],
                    out_channels=channels[self.num_stages - index - 2],
                    num_heads=num_heads,
                    drop=drop,
                    drop_path=drop_path[sum(depths[:(self.num_stages - 2 - index)]):sum(depths[:(self.num_stages - 2 - index) + 1])],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name="decoder",
                    pos_embed=pos_embed
                )
            )

        self.decoder_stages = nn.ModuleList(decoder_stages)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        self.up = FinalExpand(
            in_channels=channels[0],
            embed_dim=embed_dim,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer
        )

        self.output = nn.Conv2d(embed_dim, out_channels, kernel_size=3, padding=1)

    def encoder_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, list]:
        outs = []
        x = self.conv_first(x)

        # Save the input of each stage as the skip connection for the decoder.
        for stage in self.encoder_stages:
            outs.append(x)
            x = stage(x)

        x = self.norm(x)

        return x, outs

    def decoder_forward(self, x: torch.Tensor, x_downsample: list) -> torch.Tensor:
        # Consume the skip connections in reverse order (deepest first).
        for inx, stage in enumerate(self.decoder_stages):
            x = stage(x, x_downsample[len(x_downsample) - 1 - inx])

        x = self.norm_up(x)

        return x

    def up_x4(self, x: torch.Tensor, x_downsample: list):
        # Final 2x upsample back to the input resolution, then project to out_channels.
        x = self.up(x_downsample[0], x)
        x = self.output(x)

        return x

    def forward(self, x):
        x, x_downsample = self.encoder_forward(x)
        x = self.decoder_forward(x, x_downsample)
        x = self.up_x4(x, x_downsample)

        return x

if __name__ == '__main__':
    # Fall back to CPU so the smoke test runs without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = polarisnet(in_channels=1, embed_dim=64, pos_embed=True).to(device)

    X = torch.randn(5, 1, 224, 224, device=device)
    y = net(X)
    print(y.shape)
sft_loop.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cae9e9a28e5c3ff0d328934c066d275371d5301db084a914431198134f66ada2
size 547572280
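sft_loop.pt is stored via Git LFS (~548 MB). A minimal sketch of loading it into the model above; whether it holds a bare state_dict or a trainer-style wrapper (and the "model" key name) is an assumption, hence the unwrap and strict=False:

import torch

from polarisnet import polarisnet

ckpt = torch.load("sft_loop.pt", map_location="cpu")
# Unwrap a trainer-style dict if present; the "model" key is a guess.
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt

net = polarisnet(in_channels=1, embed_dim=64, pos_embed=True)
missing, unexpected = net.load_state_dict(state_dict, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")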