wogh2012 committed
Commit aefacda · 1 Parent(s): c69af64

refactor: add implementations

res/impl/DeepLabV3Plus.py ADDED
@@ -0,0 +1,404 @@
1
+ """
2
+ paper: https://arxiv.org/abs/1802.02611
3
+ ref:
4
+ - https://github.com/tensorflow/models/tree/master/research/deeplab
5
+ - https://github.com/VainF/DeepLabV3Plus-Pytorch
6
+ - https://github.com/Hyunjulie/KR-Reading-Computer-Vision-Papers/blob/master/DeepLabv3%2B/deeplabv3p.py
7
+ """
8
+
9
+ import math
10
+ import torch
11
+ from torch import nn
12
+ from torch.functional import F
13
+
14
+
15
+ class AtrousSeparableConv1d(nn.Module):
16
+ def __init__(
17
+ self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False
18
+ ):
19
+ super(AtrousSeparableConv1d, self).__init__()
20
+
21
+ self.depthwise = nn.Conv1d(
22
+ inplanes,
23
+ inplanes,
24
+ kernel_size,
25
+ stride,
26
+ 0,
27
+ dilation,
28
+ groups=inplanes,
29
+ bias=bias,
30
+ )
31
+ self.pointwise = nn.Conv1d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
32
+
33
+ def forward(self, x):
34
+ x = self.apply_fixed_padding(
35
+ x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]
36
+ )
37
+ x = self.depthwise(x)
38
+ x = self.pointwise(x)
39
+ return x
40
+
41
+ def apply_fixed_padding(self, inputs, kernel_size, rate):
42
+ """
43
+ Pads the input so that, for the given (dilation) rate and kernel_size, the output length can match the input length.
44
+ When stride is 2 or larger, however, the input and output lengths may still differ even after this padding.
45
+ In that case the goal is only to keep the lengths as close as possible; the final upsample stage of the network matches the size exactly.
46
+ """
47
+ kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
48
+ pad_total = kernel_size_effective - 1
49
+ pad_beg = pad_total // 2
50
+ pad_end = pad_total - pad_beg
51
+ padded_inputs = F.pad(inputs, (pad_beg, pad_end))
52
+ return padded_inputs
53
+
54
+
55
+ class Block(nn.Module):
56
+ def __init__(
57
+ self,
58
+ inplanes,
59
+ planes,
60
+ reps,
61
+ kernel_size=3,
62
+ stride=1,
63
+ dilation=1,
64
+ start_with_relu=True,
65
+ grow_first=True,
66
+ is_last=False,
67
+ ):
68
+ super(Block, self).__init__()
69
+
70
+ if planes != inplanes or stride != 1:
71
+ self.skip = nn.Conv1d(inplanes, planes, 1, stride=stride, bias=False)
72
+ self.skipbn = nn.BatchNorm1d(planes)
73
+ else:
74
+ self.skip = None
75
+
76
+ self.relu = nn.ReLU(inplace=True)
77
+ rep = []
78
+
79
+ filters = inplanes
80
+ if grow_first:
81
+ rep.append(self.relu)
82
+ rep.append(
83
+ AtrousSeparableConv1d(
84
+ inplanes, planes, kernel_size, stride=1, dilation=dilation
85
+ )
86
+ )
87
+ rep.append(nn.BatchNorm1d(planes))
88
+ filters = planes
89
+
90
+ for _ in range(reps - 1):
91
+ rep.append(self.relu)
92
+ rep.append(
93
+ AtrousSeparableConv1d(
94
+ filters, filters, kernel_size, stride=1, dilation=dilation
95
+ )
96
+ )
97
+ rep.append(nn.BatchNorm1d(filters))
98
+
99
+ if not grow_first:
100
+ rep.append(self.relu)
101
+ rep.append(
102
+ AtrousSeparableConv1d(
103
+ inplanes, planes, kernel_size, stride=1, dilation=dilation
104
+ )
105
+ )
106
+ rep.append(nn.BatchNorm1d(planes))
107
+
108
+ if not start_with_relu:
109
+ rep = rep[1:]
110
+
111
+ if stride == 2:
112
+ rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=2))
113
+ elif stride == 1:
114
+ if is_last:
115
+ rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=1))
116
+ else:
117
+ raise NotImplementedError("stride must be 1 or 2 in Block.")
118
+
119
+ self.rep = nn.Sequential(*rep)
120
+
121
+ def forward(self, inp):
122
+ x = self.rep(inp)
123
+
124
+ if self.skip is not None:
125
+ skip = self.skip(inp)
126
+ skip = self.skipbn(skip)
127
+ else:
128
+ skip = inp
129
+
130
+ x += skip
131
+
132
+ return x
133
+
134
+
135
+ class Xception(nn.Module):
136
+ """Modified Aligned Xception"""
137
+
138
+ def __init__(
139
+ self,
140
+ inplanes=1,
141
+ output_stride=16,
142
+ kernel_size=3,
143
+ middle_repeat=16,
144
+ middle_block_rate=1,
145
+ exit_block_rates=(1, 2),
146
+ ):
147
+ super(Xception, self).__init__()
148
+
149
+ if output_stride == 16:
150
+ entry3_stride = 2
151
+ elif output_stride == 8:
152
+ entry3_stride = 1
153
+ else:
154
+ raise NotImplementedError
155
+
156
+ self.conv1 = nn.Conv1d(
157
+ inplanes,
158
+ 32,
159
+ kernel_size,
160
+ stride=2,
161
+ padding=(kernel_size - 1) // 2,
162
+ bias=False,
163
+ )
164
+ self.bn1 = nn.BatchNorm1d(32)
165
+ self.relu = nn.ReLU(inplace=True)
166
+
167
+ self.conv2 = nn.Conv1d(
168
+ 32, 64, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=False
169
+ )
170
+ self.bn2 = nn.BatchNorm1d(64)
171
+
172
+ self.entry1 = Block(
173
+ 64, 128, reps=2, kernel_size=kernel_size, stride=2, start_with_relu=False
174
+ )
175
+ self.entry2 = Block(
176
+ 128,
177
+ 256,
178
+ reps=2,
179
+ kernel_size=kernel_size,
180
+ stride=2,
181
+ start_with_relu=True,
182
+ grow_first=True,
183
+ )
184
+ self.entry3 = Block(
185
+ 256,
186
+ 728,
187
+ reps=2,
188
+ kernel_size=kernel_size,
189
+ stride=entry3_stride,
190
+ start_with_relu=True,
191
+ grow_first=True,
192
+ is_last=True,
193
+ )
194
+
195
+ self.middle = nn.Sequential(
196
+ *[
197
+ Block(
198
+ 728,
199
+ 728,
200
+ reps=3,
201
+ kernel_size=kernel_size,
202
+ stride=1,
203
+ dilation=middle_block_rate,
204
+ start_with_relu=True,
205
+ grow_first=True,
206
+ )
207
+ for _ in range(middle_repeat)
208
+ ]
209
+ )
210
+
211
+ self.exit = Block(
212
+ 728,
213
+ 1024,
214
+ reps=2,
215
+ kernel_size=kernel_size,
216
+ stride=1,
217
+ dilation=exit_block_rates[0],
218
+ start_with_relu=True,
219
+ grow_first=False,
220
+ is_last=True,
221
+ )
222
+
223
+ self.conv3 = AtrousSeparableConv1d(
224
+ 1024, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
225
+ )
226
+ self.bn3 = nn.BatchNorm1d(1536)
227
+
228
+ self.conv4 = AtrousSeparableConv1d(
229
+ 1536, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
230
+ )
231
+ self.bn4 = nn.BatchNorm1d(1536)
232
+
233
+ self.conv5 = AtrousSeparableConv1d(
234
+ 1536, 2048, kernel_size, stride=1, dilation=exit_block_rates[1]
235
+ )
236
+ self.bn5 = nn.BatchNorm1d(2048)
237
+
238
+ def forward(self, x: torch.Tensor):
239
+ x = self.conv1(x)
240
+ x = self.bn1(x)
241
+ x = self.relu(x)
242
+
243
+ x = self.conv2(x)
244
+ x = self.bn2(x)
245
+ x = self.relu(x)
246
+
247
+ low_level = x = self.entry1(x)
248
+
249
+ x = self.entry2(x)
250
+ x = self.entry3(x)
251
+
252
+ x = self.middle(x)
253
+
254
+ x = self.exit(x)
255
+ x = self.conv3(x)
256
+ x = self.bn3(x)
257
+ x = self.relu(x)
258
+
259
+ x = self.conv4(x)
260
+ x = self.bn4(x)
261
+ x = self.relu(x)
262
+
263
+ x = self.conv5(x)
264
+ x = self.bn5(x)
265
+ x = self.relu(x)
266
+
267
+ return x, low_level
268
+
269
+
270
+ class ASPP(nn.Module):
271
+ """Atrous Spatial Pyramid Pooling"""
272
+
273
+ def __init__(self, inplanes, planes, rate, kernel_size=3):
274
+ super(ASPP, self).__init__()
275
+ if rate == 1:
276
+ kernel_size = 1
277
+ padding = 0
278
+ else:
279
+ padding = rate * (kernel_size - 1) // 2
280
+ self.atrous_convolution = nn.Conv1d(
281
+ inplanes,
282
+ planes,
283
+ kernel_size=kernel_size,
284
+ stride=1,
285
+ padding=padding,
286
+ dilation=rate,
287
+ bias=False,
288
+ )
289
+ self.bn = nn.BatchNorm1d(planes)
290
+ self.relu = nn.ReLU()
291
+
292
+ def forward(self, x):
293
+ x = self.atrous_convolution(x)
294
+ x = self.bn(x)
295
+
296
+ return self.relu(x)
297
+
298
+
299
+ class DeepLabV3Plus(nn.Module):
300
+ def __init__(self, config):
301
+ super(DeepLabV3Plus, self).__init__()
302
+
303
+ self.config = config
304
+ # output_stride: (input's spatial resolution / output's resolution)
305
+ output_stride = int(config.output_stride)
306
+ kernel_size = int(config.kernel_size)
307
+ middle_block_rate = int(config.middle_block_rate)
308
+ exit_block_rates: list = config.exit_block_rates
309
+ middle_repeat = int(config.middle_repeat)
310
+ self.interpolate_mode = str(config.interpolate_mode)
311
+ aspp_channel = int(config.aspp_channel)
312
+ aspp_rate: list = config.aspp_rate
313
+ output_size = config.output_size # 3(p, qrs, t)
314
+
315
+ self.xception_features = Xception(
316
+ output_stride=output_stride,
317
+ kernel_size=kernel_size,
318
+ middle_repeat=middle_repeat,
319
+ middle_block_rate=middle_block_rate,
320
+ exit_block_rates=exit_block_rates,
321
+ )
322
+
323
+ # ASPP
324
+ self.aspp1 = ASPP(
325
+ 2048, aspp_channel, rate=aspp_rate[0], kernel_size=kernel_size
326
+ )
327
+ self.aspp2 = ASPP(
328
+ 2048, aspp_channel, rate=aspp_rate[1], kernel_size=kernel_size
329
+ )
330
+ self.aspp3 = ASPP(
331
+ 2048, aspp_channel, rate=aspp_rate[2], kernel_size=kernel_size
332
+ )
333
+ self.aspp4 = ASPP(
334
+ 2048, aspp_channel, rate=aspp_rate[3], kernel_size=kernel_size
335
+ )
336
+
337
+ self.relu = nn.ReLU()
338
+
339
+ self.global_avg_pool = nn.Sequential(
340
+ nn.AdaptiveAvgPool1d(1),
341
+ nn.Conv1d(2048, aspp_channel, 1, stride=1, bias=False),
342
+ nn.BatchNorm1d(aspp_channel),
343
+ nn.ReLU(),
344
+ )
345
+
346
+ self.conv1 = nn.Conv1d(aspp_channel * 5, aspp_channel, 1, bias=False)
347
+ self.bn1 = nn.BatchNorm1d(aspp_channel)
348
+
349
+ # adopt [1x1, 48] for channel reduction.
350
+ self.conv2 = nn.Conv1d(128, 48, 1, bias=False)
351
+ self.bn2 = nn.BatchNorm1d(48)
352
+
353
+ self.last_conv = nn.Sequential(
354
+ nn.Conv1d(
355
+ aspp_channel + 48,
356
+ 256,
357
+ kernel_size=kernel_size,
358
+ stride=1,
359
+ padding=(kernel_size - 1) // 2,
360
+ bias=False,
361
+ ),
362
+ nn.BatchNorm1d(256),
363
+ nn.ReLU(),
364
+ nn.Conv1d(
365
+ 256,
366
+ 256,
367
+ kernel_size=kernel_size,
368
+ stride=1,
369
+ padding=(kernel_size - 1) // 2,
370
+ bias=False,
371
+ ),
372
+ nn.BatchNorm1d(256),
373
+ nn.ReLU(),
374
+ nn.Conv1d(256, output_size, kernel_size=1, stride=1),
375
+ )
376
+
377
+ def forward(self, input):
378
+ x, low_level_features = self.xception_features(input)
379
+
380
+ x1 = self.aspp1(x)
381
+ x2 = self.aspp2(x)
382
+ x3 = self.aspp3(x)
383
+ x4 = self.aspp4(x)
384
+ x5 = self.global_avg_pool(x)
385
+ x5 = F.interpolate(x5, size=x4.shape[2:], mode=self.interpolate_mode)
386
+
387
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
388
+
389
+ x = self.conv1(x)
390
+ x = self.bn1(x)
391
+ x = self.relu(x)
392
+ x = F.interpolate(
393
+ x, size=int(math.ceil(input.shape[-1] / 4)), mode=self.interpolate_mode
394
+ )
395
+
396
+ low_level_features = self.conv2(low_level_features)
397
+ low_level_features = self.bn2(low_level_features)
398
+ low_level_features = self.relu(low_level_features)
399
+
400
+ x = torch.cat((x, low_level_features), dim=1)
401
+ x = self.last_conv(x)
402
+ x = F.interpolate(x, size=input.shape[2:], mode=self.interpolate_mode)
403
+
404
+ return x
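A minimal usage sketch for the 1-D DeepLabV3Plus above (not part of the commit). The config field names mirror those read in __init__; the values are illustrative assumptions, not this repo's actual settings:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        output_stride=16, kernel_size=3, middle_repeat=16, middle_block_rate=1,
        exit_block_rates=[1, 2], interpolate_mode="linear",
        aspp_channel=256, aspp_rate=[1, 6, 12, 18], output_size=3,
    )
    model = DeepLabV3Plus(cfg).eval()
    x = torch.randn(2, 1, 2048)   # (batch, channel, length), e.g. a 1-D ECG segment
    with torch.no_grad():
        y = model(x)              # (2, 3, 2048): per-sample logits for the p/qrs/t classes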
res/impl/FCN.py ADDED
@@ -0,0 +1,130 @@
1
+ """
2
+ paper: https://arxiv.org/abs/1605.06211
3
+ ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class FCN(nn.Module):
11
+ def __init__(self, config):
12
+ super().__init__()
13
+
14
+ self.config = config
15
+ self.kernel_size = int(config.kernel_size)
16
+ last_layer_kernel_size = int(config.last_layer_kernel_size)
17
+ inplanes = int(config.inplanes)
18
+ combine_conf: dict = config.combine_conf
19
+ self.num_layers = int(combine_conf["num_layers"])
20
+ self.first_padding = {6: 240, 5: 130, 4: 80}[self.num_layers]
21
+ self.num_convs = int(config.num_convs)
22
+ self.dilation = int(config.dilation)
23
+ self.combine_until = int(combine_conf["combine_until"])
24
+ assert self.combine_until < self.num_layers
25
+ dropout = float(config.dropout)
26
+ output_size = config.output_size # 3(p, qrs, t)
27
+
28
+ self.layers = nn.ModuleList()
29
+ for i in range(self.num_layers):
30
+ self.layers.append(
31
+ self._make_layer(
32
+ 1 if i == 0 else inplanes * (2 ** (i - 1)),
33
+ inplanes * (2 ** (i)),
34
+ is_first=True if i == 0 else False,
35
+ )
36
+ )
37
+ # final conv layer with no pooling step; unlike the other layers, the number of convs (2) and the channel width are fixed, and dropout is applied
38
+ self.layers.append(
39
+ nn.Sequential(
40
+ nn.Conv1d(inplanes * (2 ** (i)), 4096, last_layer_kernel_size),
41
+ nn.BatchNorm1d(4096),
42
+ nn.ReLU(),
43
+ nn.Dropout(dropout),
44
+ nn.Conv1d(4096, 4096, 1),
45
+ nn.BatchNorm1d(4096),
46
+ nn.ReLU(),
47
+ nn.Dropout(dropout),
48
+ )
49
+ )
50
+ self.score_convs = []
51
+ self.up_convs = []
52
+ for i in range(self.combine_until, self.num_layers - 1):
53
+ # score_convs and up_convs are created only for as many pool outputs as are combined
54
+ self.score_convs.append(
55
+ nn.Conv1d(inplanes * (2 ** (i)), output_size, kernel_size=1, bias=False)
56
+ )
57
+ self.up_convs.append(
58
+ nn.ConvTranspose1d(output_size, output_size, kernel_size=4, stride=2)
59
+ )
60
+ # score_conv applied to the output of the final conv stack (which has no pool)
61
+ # self.score_convs always has one more entry than self.up_convs
62
+ self.score_convs.append(nn.Conv1d(4096, output_size, kernel_size=1, bias=False))
63
+
64
+ self.score_convs.reverse()
65
+ self.score_convs = nn.ModuleList(self.score_convs)
66
+ self.up_convs = nn.ModuleList(self.up_convs)
67
+ self.last_up_convs = nn.ConvTranspose1d(
68
+ output_size,
69
+ output_size,
70
+ kernel_size=2 ** (self.combine_until + 1) * 2, # stride * 2
71
+ stride=2 ** (self.combine_until + 1),
72
+ )
73
+
74
+ def _make_layer(
75
+ self,
76
+ in_channel: int,
77
+ out_channel: int,
78
+ is_first: bool = False,
79
+ ):
80
+ layer = []
81
+ plane = in_channel
82
+ for idx in range(self.num_convs):
83
+ layer.append(
84
+ nn.Conv1d(
85
+ plane,
86
+ out_channel,
87
+ kernel_size=self.kernel_size,
88
+ padding=self.first_padding
89
+ if idx == 0 and is_first
90
+ else (self.dilation * (self.kernel_size - 1)) // 2,
91
+ dilation=self.dilation,
92
+ bias=False,
93
+ )
94
+ )
95
+ layer.append(nn.BatchNorm1d(out_channel))
96
+ layer.append(nn.ReLU())
97
+ plane = out_channel
98
+
99
+ layer.append(nn.MaxPool1d(2, 2, ceil_mode=True))
100
+ return nn.Sequential(*layer)
101
+
102
+ def forward(self, input: torch.Tensor, y=None):
103
+ output: torch.Tensor = input
104
+
105
+ pools = []
106
+ for idx, layer in enumerate(self.layers):
107
+ output = layer(output)
108
+ if self.combine_until <= idx < (self.num_layers - 1):
109
+ pools.append(output)
110
+ pools.reverse()
111
+
112
+ output = self.score_convs[0](output)
113
+ if len(pools) > 0:
114
+ output = self.up_convs[0](output)
115
+ for i in range(len(pools)):
116
+ score_pool = self.score_convs[i + 1](pools[i])
117
+ offset = (score_pool.shape[2] - output.shape[2]) // 2
118
+ cropped_score_pool = torch.tensor_split(
119
+ score_pool, (offset, offset + output.shape[2]), dim=2
120
+ )[1]
121
+ output = torch.add(cropped_score_pool, output)
122
+ if i < len(pools) - 1:  # the last upsample uses last_up_convs instead
123
+ output = self.up_convs[i + 1](output)
124
+ output = self.last_up_convs(output)
125
+
126
+ offset = (output.shape[2] - input.shape[2]) // 2
127
+ cropped_score_pool = torch.tensor_split(
128
+ output, (offset, offset + input.shape[2]), dim=2
129
+ )[1]
130
+ return cropped_score_pool
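A usage sketch for the FCN above (not part of the commit). The config fields mirror __init__; the values are one combination checked by hand, and since the pad/crop arithmetic is length-dependent, other input lengths may need different settings:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        kernel_size=3, last_layer_kernel_size=7, inplanes=64,
        combine_conf={"num_layers": 5, "combine_until": 2},
        num_convs=2, dilation=1, dropout=0.5, output_size=3,
    )
    model = FCN(cfg).eval()
    x = torch.randn(2, 1, 2000)
    with torch.no_grad():
        y = model(x)              # (2, 3, 2000) after the final center crop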
res/impl/HRNetV2.py ADDED
@@ -0,0 +1,378 @@
1
+ """
2
+ paper: https://arxiv.org/abs/1904.04514
3
+ ref: https://github.com/HRNet/HRNet-Semantic-Segmentation/blob/HRNet-OCR/lib/models/seg_hrnet.py
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.functional import F
9
+ import math
10
+
11
+
12
+ def _gen_same_length_conv(in_channel, out_channel, kernel_size=1, dilation=1):
13
+ """길이가 변하지 않는 conv 생성, block 내에서 feature 를 추출하는 convolution 에서 사용"""
14
+ return nn.Conv1d(
15
+ in_channel,
16
+ out_channel,
17
+ kernel_size=kernel_size,
18
+ stride=1,
19
+ padding=(dilation * (kernel_size - 1)) // 2,
20
+ dilation=dilation,
21
+ bias=False,
22
+ )
23
+
24
+
25
+ def _gen_downsample(in_channel, out_channel):
26
+ """kernel_size:3, stride:2, padding:1 인 2배 downsample 하는 conv 생성"""
27
+ return nn.Conv1d(
28
+ in_channel, out_channel, kernel_size=3, stride=2, padding=1, bias=False
29
+ )
30
+
31
+
32
+ def _gen_channel_change_conv(in_channel, out_channel):
33
+ """kernel_size:1, stride:1 인 channel 변경하는 conv 생성"""
34
+ return nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1, bias=False)
35
+
36
+
37
+ class BasicBlock(nn.Module):
38
+ """resnet 의 basic block 으로 channel 변화는 inplanes -> planes"""
39
+
40
+ expansion = 1
41
+
42
+ def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
43
+ super().__init__()
44
+ self.conv1 = _gen_same_length_conv(inplanes, planes, kernel_size, dilation)
45
+ self.bn1 = nn.BatchNorm1d(planes)
46
+ self.relu = nn.ReLU()
47
+ self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
48
+ self.bn2 = nn.BatchNorm1d(planes)
49
+ self.make_residual = (
50
+ _gen_channel_change_conv(inplanes, planes)
51
+ if inplanes != planes
52
+ else nn.Identity()
53
+ )
54
+
55
+ def forward(self, x):
56
+ out = self.conv1(x)
57
+ out = self.bn1(out)
58
+ out = self.relu(out)
59
+
60
+ out = self.conv2(out)
61
+ out = self.bn2(out)
62
+
63
+ residual = self.make_residual(x)
64
+
65
+ out = out + residual
66
+ out = self.relu(out)
67
+
68
+ return out
69
+
70
+
71
+ class Bottleneck(nn.Module):
72
+ """resnet 의 Bottleneck block 으로 channel 변화는 inplanes -> planes * 4"""
73
+
74
+ expansion = 4
75
+
76
+ def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
77
+ super().__init__()
78
+ self.conv1 = _gen_same_length_conv(inplanes, planes)
79
+ self.bn1 = nn.BatchNorm1d(planes)
80
+ self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
81
+ self.bn2 = nn.BatchNorm1d(planes)
82
+ self.conv3 = _gen_same_length_conv(planes, planes * self.expansion)
83
+ self.bn3 = nn.BatchNorm1d(planes * self.expansion)
84
+ self.relu = nn.ReLU()
85
+ self.make_residual = (
86
+ _gen_channel_change_conv(inplanes, planes * self.expansion)
87
+ if inplanes != planes * self.expansion
88
+ else nn.Identity()
89
+ )
90
+
91
+ def forward(self, x):
92
+ out = self.conv1(x)
93
+ out = self.bn1(out)
94
+ out = self.relu(out)
95
+
96
+ out = self.conv2(out)
97
+ out = self.bn2(out)
98
+ out = self.relu(out)
99
+
100
+ out = self.conv3(out)
101
+ out = self.bn3(out)
102
+
103
+ residual = self.make_residual(x)
104
+
105
+ out = out + residual
106
+ out = self.relu(out)
107
+
108
+ return out
109
+
110
+
111
+ class HRModule(nn.Module):
112
+ def __init__(
113
+ self,
114
+ stage_idx,
115
+ num_blocks,
116
+ block_type_by_stage,
117
+ in_channels_by_stage,
118
+ out_channels_by_stage,
119
+ data_len_by_branch,
120
+ kernel_size,
121
+ dilation,
122
+ interpolate_mode,
123
+ ):
124
+ super().__init__()
125
+
126
+ self.branches = nn.ModuleList()
127
+ self.fusions = nn.ModuleList()
128
+
129
+ block_type: BasicBlock | Bottleneck = block_type_by_stage[stage_idx]
130
+ in_channels = in_channels_by_stage[stage_idx]
131
+ for i in range(stage_idx + 1):  # create branches
132
+ blocks_by_branch = []
133
+ _channels = in_channels[i]
134
+ blocks_by_branch.append(
135
+ block_type(_channels, _channels, kernel_size, dilation)
136
+ )
137
+ for _ in range(1, num_blocks):
138
+ blocks_by_branch.append(
139
+ block_type(
140
+ _channels * block_type.expansion,
141
+ _channels,
142
+ kernel_size,
143
+ dilation,
144
+ )
145
+ )
146
+ self.branches.append(nn.Sequential(*blocks_by_branch))
147
+
148
+ out_channels = out_channels_by_stage[stage_idx]
149
+ for i in range(stage_idx + 1):
150
+ fusion_by_branch = nn.ModuleList()
151
+ for j in range(stage_idx + 1):
152
+ if i < j:
153
+ fusion_by_branch.append(
154
+ nn.Sequential(
155
+ _gen_channel_change_conv(out_channels[j], in_channels[i]),
156
+ nn.BatchNorm1d(in_channels[i]),
157
+ nn.Upsample(
158
+ size=data_len_by_branch[i], mode=interpolate_mode
159
+ ),
160
+ )
161
+ )
162
+ elif i == j:
163
+ if out_channels[i] != in_channels[j]:
164
+ fusion_by_branch.append(
165
+ nn.Sequential(
166
+ _gen_channel_change_conv(
167
+ out_channels[i], in_channels[j]
168
+ ),
169
+ nn.BatchNorm1d(in_channels[j]),
170
+ nn.ReLU(),
171
+ )
172
+ )
173
+ else:
174
+ fusion_by_branch.append(nn.Identity())
175
+ else:
176
+ # downsample by 2 for each level of branch difference; channels are matched to the current branch's in_channel
177
+ downsamples = [
178
+ _gen_downsample(out_channels[j], in_channels[i]),
179
+ nn.BatchNorm1d(in_channels[i]),
180
+ ]
181
+ for _ in range(1, i - j):
182
+ downsamples.extend(
183
+ [
184
+ nn.ReLU(),
185
+ _gen_downsample(in_channels[i], in_channels[i]),
186
+ nn.BatchNorm1d(in_channels[i]),
187
+ ]
188
+ )
189
+ fusion_by_branch.append(nn.Sequential(*downsamples))
190
+ self.fusions.append(fusion_by_branch)
191
+
192
+
193
+ class HRNetV2(nn.Module):
194
+ def __init__(self, config):
195
+ super().__init__()
196
+
197
+ self.config = config
198
+ data_len = int(config.data_len)  # matched to ECGPQRSTDataset.second and the sampling rate (hz)
199
+ kernel_size = int(config.kernel_size)
200
+ dilation = int(config.dilation)
201
+ num_stages = int(config.num_stages)
202
+ num_blocks = int(config.num_blocks)
203
+ self.num_modules = config.num_modules # [1, 1, 4, 3, ..]
204
+ assert num_stages <= len(self.num_modules)
205
+ use_bottleneck = config.use_bottleneck # [1, 0, 0, 0, ..]
206
+ assert num_stages <= len(use_bottleneck)
207
+ stage1_channels = int(config.stage1_channels) # 64, 128
208
+ num_channels_init = int(config.num_channels_init) # 18, 32, 48
209
+ self.interpolate_mode = config.interpolate_mode
210
+ output_size = config.output_size # 3(p, qrs, t)
211
+
212
+ # stem
213
+ self.stem = nn.Sequential(
214
+ nn.Conv1d(
215
+ 1, stage1_channels, kernel_size=3, stride=2, padding=1, bias=False
216
+ ),
217
+ nn.BatchNorm1d(stage1_channels),
218
+ nn.Conv1d(
219
+ stage1_channels,
220
+ stage1_channels,
221
+ kernel_size=3,
222
+ stride=2,
223
+ padding=1,
224
+ bias=False,
225
+ ),
226
+ nn.BatchNorm1d(stage1_channels),
227
+ nn.ReLU(),
228
+ )
229
+ for _ in range(2):  # compute the data length after passing through the stem
230
+ data_len = math.floor((data_len - 1) / 2 + 1)
231
+
232
+ # create meta: before building the network, precompute per-stage info such as in_channel and out_channel
233
+ in_channels_by_stage = []
234
+ out_channels_by_stage = []
235
+ block_type_by_stage = []
236
+ for stage_idx in range(num_stages):
237
+ block_type_each_stage = (
238
+ Bottleneck if use_bottleneck[stage_idx] == 1 else BasicBlock
239
+ )
240
+ if stage_idx == 0:
241
+ in_channels_each_stage = [stage1_channels]
242
+ out_channels_each_stage = [
243
+ stage1_channels * block_type_each_stage.expansion
244
+ ]
245
+ data_len_by_branch = [data_len]
246
+ else:
247
+ in_channels_each_stage = [
248
+ num_channels_init * 2**idx for idx in range(stage_idx + 1)
249
+ ]
250
+ out_channels_each_stage = [
251
+ (num_channels_init * 2**idx) * block_type_each_stage.expansion
252
+ for idx in range(stage_idx + 1)
253
+ ]
254
+ data_len_by_branch.append(
255
+ math.floor((data_len_by_branch[-1] - 1) / 2 + 1)
256
+ )
257
+
258
+ block_type_by_stage.append(block_type_each_stage)
259
+ in_channels_by_stage.append(in_channels_each_stage)
260
+ out_channels_by_stage.append(out_channels_each_stage)
261
+
262
+ # create stages
263
+ self.stages = nn.ModuleList()
264
+ for stage_idx in range(num_stages):
265
+ modules_by_stage = nn.ModuleList()
266
+ for _ in range(self.num_modules[stage_idx]):
267
+ modules_by_stage.append(
268
+ HRModule(
269
+ stage_idx,
270
+ num_blocks,
271
+ block_type_by_stage,
272
+ in_channels_by_stage,
273
+ out_channels_by_stage,
274
+ data_len_by_branch,
275
+ kernel_size,
276
+ dilation,
277
+ self.interpolate_mode,
278
+ )
279
+ )
280
+ self.stages.append(modules_by_stage)
281
+
282
+ # create transition
283
+ self.transitions = nn.ModuleList()
284
+ for stage_idx in range(num_stages - 1):
285
+ # here stage_idx refers to the previous stage; transitions change channels between stages or create a new branch
286
+ transition_by_stage = nn.ModuleList()
287
+ psc = in_channels_by_stage[stage_idx] # psc: prev_stage_channels
288
+ nsc = in_channels_by_stage[stage_idx + 1] # nsc: next_stage_channels
289
+ for nsbi in range(stage_idx + 2): # nsbi: next_stage_branch_idx
290
+ if nsbi < stage_idx + 1:  # same branch level
291
+ if psc[nsbi] != nsc[nsbi]:
292
+ transition_by_stage.append(
293
+ nn.Sequential(
294
+ _gen_channel_change_conv(psc[nsbi], nsc[nsbi]),
295
+ nn.BatchNorm1d(nsc[nsbi]),
296
+ nn.ReLU(),
297
+ )
298
+ )
299
+ else:
300
+ transition_by_stage.append(nn.Identity())
301
+ else:  # create a new branch from the existing branches
302
+ transition_from_branches = nn.ModuleList()
303
+ for psbi in range(nsbi):
304
+ # psbi: prev_stage_branch_idx
305
+ transition_from_one_branch = [
306
+ _gen_downsample(psc[psbi], nsc[nsbi]),
307
+ nn.BatchNorm1d(nsc[nsbi]),
308
+ ]
309
+ for _ in range(1, nsbi - psbi):
310
+ transition_from_one_branch.extend(
311
+ [
312
+ nn.ReLU(),
313
+ _gen_downsample(nsc[nsbi], nsc[nsbi]),
314
+ nn.BatchNorm1d(nsc[nsbi]),
315
+ ]
316
+ )
317
+ transition_from_branches.append(
318
+ nn.Sequential(*transition_from_one_branch)
319
+ )
320
+ transition_by_stage.append(transition_from_branches)
321
+ self.transitions.append(transition_by_stage)
322
+
323
+ self.cls = nn.Conv1d(sum(in_channels_each_stage), output_size, 1, bias=False)
324
+
325
+ def forward(self, input: torch.Tensor, y=None):
326
+ output: torch.Tensor = input
327
+
328
+ output = self.stem(output)
329
+
330
+ outputs = [output]
331
+ for stage_idx, stage in enumerate(self.stages):
332
+ for module_idx in range(self.num_modules[stage_idx]):
333
+ for branch_idx in range(stage_idx + 1):
334
+ outputs[branch_idx] = stage[module_idx].branches[branch_idx](
335
+ outputs[branch_idx]
336
+ )
337
+ fusion_outputs = []
338
+ for next in range(stage_idx + 1):
339
+ fusion_output_from_branches = []
340
+ for prev in range(stage_idx + 1):
341
+ fusion_output_from_branch: torch.Tensor = stage[
342
+ module_idx
343
+ ].fusions[next][prev](outputs[prev])
344
+ fusion_output_from_branches.append(fusion_output_from_branch)
345
+ fusion_outputs.append(sum(fusion_output_from_branches))
346
+ outputs = fusion_outputs
347
+
348
+ if stage_idx < len(self.stages) - 1:
349
+ transition_outputs = []
350
+ for trans_idx, transition in enumerate(self.transitions[stage_idx]):
351
+ # each transition holds as many Sequentials or ModuleLists as the next stage has branches
352
+ # the leading Sequentials only change channels to match the next stage, or pass the branch through unchanged (Identity)
353
+ # the last ModuleList creates the new branch from downsampled fusion results of each existing branch
354
+ if trans_idx < stage_idx + 1:
355
+ transition_outputs.append(transition(outputs[trans_idx]))
356
+ else:
357
+ transition_outputs.append(
358
+ sum(
359
+ [
360
+ transition_from_each_branch(output)
361
+ for transition_from_each_branch, output in zip(
362
+ transition, outputs
363
+ )
364
+ ]
365
+ )
366
+ )
367
+ outputs = transition_outputs
368
+
369
+ # HRNetV2
370
+ outputs = [
371
+ F.interpolate(output, size=outputs[0].shape[-1], mode=self.interpolate_mode)
372
+ for output in outputs
373
+ ]
374
+ output = torch.cat(outputs, dim=1)
375
+
376
+ return F.interpolate(
377
+ self.cls(output), size=input.shape[-1], mode=self.interpolate_mode
378
+ )
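An illustrative HRNetV2 instantiation (not part of the commit). The fusion Upsample sizes are precomputed from config.data_len, so the input length must equal data_len; the values below are assumptions in the spirit of HRNet-W18:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        data_len=2048, kernel_size=3, dilation=1,
        num_stages=4, num_blocks=4, num_modules=[1, 1, 4, 3],
        use_bottleneck=[1, 0, 0, 0], stage1_channels=64, num_channels_init=18,
        interpolate_mode="linear", output_size=3,
    )
    model = HRNetV2(cfg).eval()
    x = torch.randn(2, 1, cfg.data_len)
    with torch.no_grad():
        y = model(x)              # (2, 3, 2048): concatenated multi-resolution features -> 1x1 cls conv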
res/impl/PSPNet.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ paper: https://arxiv.org/abs/1612.01105
3
+ ref:
4
+ - https://github.com/hszhao/PSPNet
5
+ """
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.functional import F
10
+
11
+
12
+ class PPM(nn.Module):
13
+ """Pyramid Pooling Module"""
14
+
15
+ def __init__(self, in_dim, reduction_dim, bins, interplate_mode):
16
+ super(PPM, self).__init__()
17
+ self.features = []
18
+ for bin in bins:
19
+ self.features.append(
20
+ nn.Sequential(
21
+ nn.AdaptiveAvgPool1d(bin),
22
+ nn.Conv1d(in_dim, reduction_dim, kernel_size=1, bias=False),
23
+ nn.BatchNorm1d(reduction_dim),
24
+ nn.ReLU(),
25
+ )
26
+ )
27
+ self.features = nn.ModuleList(self.features)
28
+ self.interplate_mode = interplate_mode
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x_size = x.size()
32
+ out = [x]
33
+ for f in self.features:
34
+ out.append(F.interpolate(f(x), x_size[2], mode=self.interplate_mode))
35
+ return torch.cat(out, dim=1)
36
+
37
+
38
+ class Bottleneck(nn.Module):
39
+ def __init__(
40
+ self,
41
+ inplanes,
42
+ planes,
43
+ expansion=4,
44
+ kernel_size=3,
45
+ stride=1,
46
+ dilation=1,
47
+ padding=1,
48
+ downsample=None,
49
+ ):
50
+ super(Bottleneck, self).__init__()
51
+ self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
52
+ self.bn1 = nn.BatchNorm1d(planes)
53
+ self.conv2 = nn.Conv1d(
54
+ planes,
55
+ planes,
56
+ kernel_size=kernel_size,
57
+ stride=stride,
58
+ dilation=dilation,
59
+ padding=padding,
60
+ bias=False,
61
+ )
62
+ self.bn2 = nn.BatchNorm1d(planes)
63
+ self.conv3 = nn.Conv1d(planes, planes * expansion, kernel_size=1, bias=False)
64
+ self.bn3 = nn.BatchNorm1d(planes * expansion)
65
+ self.relu = nn.ReLU()
66
+ self.downsample = downsample
67
+
68
+ def forward(self, x):
69
+ residual = x
70
+
71
+ out = self.conv1(x)
72
+ out = self.bn1(out)
73
+ out = self.relu(out)
74
+
75
+ out = self.conv2(out)
76
+ out = self.bn2(out)
77
+ out = self.relu(out)
78
+
79
+ out = self.conv3(out)
80
+ out = self.bn3(out)
81
+
82
+ if self.downsample is not None:
83
+ residual = self.downsample(x)
84
+
85
+ out += residual
86
+ out = self.relu(out)
87
+
88
+ return out
89
+
90
+
91
+ class PSPNet(nn.Module):
92
+ def __init__(self, config):
93
+ super(PSPNet, self).__init__()
94
+
95
+ self.config = config
96
+ self.kernel_size = int(config.kernel_size)
97
+ self.padding = (self.kernel_size - 1) // 2
98
+ self.expansion = int(config.expansion)
99
+ self.inplanes = int(config.inplanes)
100
+ num_layers = int(config.num_layers)
101
+ self.num_bottlenecks = int(config.num_bottlenecks)
102
+ self.interpolate_mode = str(config.interpolate_mode)
103
+ self.dilation = int(config.dilation)
104
+ ppm_bins: list = config.ppm_bins
105
+ self.aux_idx = int(config.aux_idx)
106
+ assert self.aux_idx < num_layers
107
+ self.aux_ratio = float(config.aux_ratio)
108
+ dropout = float(config.dropout)
109
+ output_size = config.output_size # 3(p, qrs, t)
110
+
111
+ # after the stem stage, the signal starts out downsampled by a factor of 4
112
+ self.stem = nn.Sequential(
113
+ *[
114
+ nn.Conv1d(
115
+ 1,
116
+ self.inplanes,
117
+ self.kernel_size,
118
+ stride=2,
119
+ padding=self.padding,
120
+ bias=False,
121
+ ),
122
+ nn.BatchNorm1d(self.inplanes),
123
+ nn.ReLU(),
124
+ nn.MaxPool1d(self.kernel_size, stride=2, padding=self.padding),
125
+ ]
126
+ )
127
+
128
+ self.layers = []
129
+ plane = self.inplanes
130
+ for i in range(num_layers):
131
+ self.layers.append(self._make_layer(plane * (2 ** (i))))
132
+ self.layers = nn.ModuleList(self.layers)
133
+
134
+ encode_dim = self.inplanes
135
+ self.ppm = PPM(
136
+ encode_dim,
137
+ int(encode_dim / len(ppm_bins)),
138
+ ppm_bins,
139
+ self.interpolate_mode,
140
+ )
141
+ encode_dim *= 2
142
+ self.cls = nn.Sequential(
143
+ nn.Conv1d(
144
+ encode_dim,
145
+ 512,
146
+ kernel_size=self.kernel_size,
147
+ padding=self.padding,
148
+ bias=False,
149
+ ),
150
+ nn.BatchNorm1d(512),
151
+ nn.ReLU(),
152
+ nn.Dropout1d(dropout),
153
+ nn.Conv1d(512, output_size, kernel_size=1),
154
+ )
155
+ self.aux_branch = nn.Sequential(
156
+ # must match the channel count of the layer index (aux_idx) being tapped
157
+ nn.Conv1d(
158
+ plane * self.expansion * (2**self.aux_idx),
159
+ 256,
160
+ kernel_size=self.kernel_size,
161
+ padding=self.padding,
162
+ bias=False,
163
+ ),
164
+ nn.BatchNorm1d(256),
165
+ nn.ReLU(),
166
+ nn.Dropout1d(0.1),
167
+ nn.Conv1d(256, output_size, kernel_size=1),
168
+ )
169
+
170
+ def _make_layer(self, planes: int):
171
+ """
172
+ self.num_bottlenecks 개의 bottleneck 으로 구성된 layer 를 반환
173
+ 첫번째 bottleneck 에서 2 만큼 downsample 됨
174
+ 두번째 이후부터의 bottleneck 에서 self.dilation 으로 dilated conv 수행
175
+ """
176
+ downsample = nn.Sequential(
177
+ nn.Conv1d(
178
+ self.inplanes,
179
+ planes * self.expansion,
180
+ kernel_size=1,
181
+ stride=2,
182
+ bias=False,
183
+ ),
184
+ nn.BatchNorm1d(planes * self.expansion),
185
+ )
186
+
187
+ bottlenecks = []
188
+ bottlenecks.append(
189
+ Bottleneck(
190
+ self.inplanes,
191
+ planes,
192
+ expansion=self.expansion,
193
+ kernel_size=self.kernel_size,
194
+ stride=2,
195
+ dilation=1,
196
+ padding=self.padding,
197
+ downsample=downsample,
198
+ )
199
+ )
200
+ self.inplanes = planes * self.expansion
201
+ for _ in range(1, self.num_bottlenecks):
202
+ bottlenecks.append(
203
+ Bottleneck(
204
+ self.inplanes,
205
+ planes,
206
+ expansion=self.expansion,
207
+ kernel_size=self.kernel_size,
208
+ stride=1,
209
+ dilation=self.dilation,
210
+ padding=(self.dilation * (self.kernel_size - 1)) // 2,
211
+ )
212
+ )
213
+
214
+ return nn.Sequential(*bottlenecks)
215
+
216
+ def forward(self, input: torch.Tensor, y=None):
217
+ output: torch.Tensor = input
218
+ output = self.stem(output)
219
+ for i, _layer in enumerate(self.layers):
220
+ output = _layer(output)
221
+ if i == self.aux_idx:
222
+ aux = output
223
+
224
+ output = self.ppm(output)
225
+ output = self.cls(output)
226
+ output = F.interpolate(
227
+ output,
228
+ input.shape[2],
229
+ mode=self.interpolate_mode,
230
+ )
231
+ if self.training:
232
+ aux = self.aux_branch(aux)
233
+ aux = F.interpolate(
234
+ aux,
235
+ input.shape[2],
236
+ mode=self.interpolate_mode,
237
+ )
238
+ return torch.add(output * (1 - self.aux_ratio), aux * self.aux_ratio)
239
+ else:
240
+ return output
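A sketch of how the PSPNet above might be configured (illustrative values, not the repo's configs). In training mode the forward returns a blend of the main and auxiliary heads, so eval() is used here to get the plain segmentation output:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        kernel_size=3, expansion=4, inplanes=64, num_layers=4, num_bottlenecks=3,
        interpolate_mode="linear", dilation=2, ppm_bins=[1, 2, 3, 6],
        aux_idx=2, aux_ratio=0.4, dropout=0.1, output_size=3,
    )
    model = PSPNet(cfg).eval()
    x = torch.randn(2, 1, 2048)
    with torch.no_grad():
        y = model(x)              # (2, 3, 2048)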
res/impl/SETR.py ADDED
@@ -0,0 +1,291 @@
1
+ """
2
+ paper: https://arxiv.org/abs/2012.15840
3
+ - ref
4
+ - encoder:
5
+ - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py
6
+ - https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py
7
+ - decoder:
8
+ - https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_up_head.py
9
+ - https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_mla_head.py
10
+
11
+ - encoder: same structure as ViT; PatchEmbed uses a Conv1d with patch_size as both kernel_size and stride
12
+ - decoder: two upsampling schemes are used (scale_factor: upsample by a fixed factor / size: upsample to a given target length)
13
+ - naive: a single size-style upsample to the original length
14
+ - pup: repeated scale_factor-style upsamples, finished with a size-style upsample to the original length
15
+ - mla: two steps; the intermediate transformer-block outputs are first upsampled scale_factor-style, then concatenated and upsampled size-style
16
+ """
17
+
18
+ import math
19
+ import torch
20
+ from torch import nn
21
+ from einops import rearrange
22
+
23
+
24
+ class FeedForward(nn.Module):
25
+ def __init__(self, dim, hidden_dim, dropout=0.0):
26
+ super().__init__()
27
+ self.net = nn.Sequential(
28
+ nn.LayerNorm(dim),
29
+ nn.Linear(dim, hidden_dim),
30
+ nn.GELU(),
31
+ nn.Dropout(dropout),
32
+ nn.Linear(hidden_dim, dim),
33
+ nn.Dropout(dropout),
34
+ )
35
+
36
+ def forward(self, x):
37
+ return self.net(x)
38
+
39
+
40
+ class Attention(nn.Module):
41
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
42
+ super().__init__()
43
+ inner_dim = dim_head * heads
44
+ project_out = not (heads == 1 and dim_head == dim)
45
+
46
+ self.heads = heads
47
+ self.scale = dim_head**-0.5
48
+
49
+ self.norm = nn.LayerNorm(dim)
50
+ self.attend = nn.Softmax(dim=-1)
51
+ self.dropout = nn.Dropout(dropout)
52
+
53
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
54
+
55
+ self.to_out = (
56
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
57
+ if project_out
58
+ else nn.Identity()
59
+ )
60
+
61
+ def forward(self, x):
62
+ x = self.norm(x)
63
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
64
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
65
+
66
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
67
+
68
+ attn = self.attend(dots)
69
+ attn = self.dropout(attn)
70
+
71
+ out = torch.matmul(attn, v)
72
+ out = rearrange(out, "b h n d -> b n (h d)")
73
+ return self.to_out(out)
74
+
75
+
76
+ # ========== Up to here: adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py ==========
77
+ # ========== From here on: based on the original SETR reference https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py ==========
78
+
79
+
80
+ class TransformerBlock(nn.Module):
81
+ def __init__(
82
+ self,
83
+ dim,
84
+ num_attn_heads,
85
+ attn_head_dim,
86
+ mlp_dim,
87
+ attn_dropout=0.0,
88
+ ffn_dropout=0.0,
89
+ ):
90
+ super().__init__()
91
+ self.attn = Attention(
92
+ dim, heads=num_attn_heads, dim_head=attn_head_dim, dropout=attn_dropout
93
+ )
94
+ self.ffn = FeedForward(dim, mlp_dim, dropout=ffn_dropout)
95
+
96
+ def forward(self, x):
97
+ x = self.attn(x) + x
98
+ x = self.ffn(x) + x
99
+ return x
100
+
101
+
102
+ class PatchEmbed(nn.Module):
103
+ def __init__(
104
+ self,
105
+ embed_dim=1024,
106
+ kernel_size=16,
107
+ bias=False,
108
+ ):
109
+ super().__init__()
110
+
111
+ self.projection = nn.Conv1d(
112
+ in_channels=1,
113
+ out_channels=embed_dim,
114
+ kernel_size=kernel_size,
115
+ stride=kernel_size,
116
+ bias=bias,
117
+ )
118
+
119
+ def forward(self, x: torch.Tensor):
120
+ return self.projection(x).transpose(1, 2)
121
+
122
+
123
+ class SETR(nn.Module):
124
+ def __init__(self, config):
125
+ super().__init__()
126
+
127
+ embed_dim = int(config.embed_dim)
128
+ data_len = int(config.data_len)  # matched to ECGPQRSTDataset.second and the sampling rate (hz)
129
+ patch_size = int(config.patch_size)
130
+ assert data_len % patch_size == 0
131
+ num_patches = data_len // patch_size
132
+ patch_bias = bool(config.patch_bias)
133
+ dropout = float(config.dropout)
134
+ # pos_dropout_p: float = config.pos_dropout_p  # there are already too many parameters, so a single dropout value is used for now
135
+ num_layers = int(config.num_layers)  # number of transformer blocks
136
+ num_attn_heads = int(config.num_attn_heads)
137
+ attn_head_dim = int(config.attn_head_dim)
138
+ mlp_dim = int(config.mlp_dim)
139
+ # attn_dropout: float = config.attn_dropout
140
+ # ffn_dropout: float = config.ffn_dropout
141
+ interpolate_mode = str(config.interpolate_mode)
142
+ dec_conf: dict = config.dec_conf
143
+ assert len(dec_conf) == 1
144
+ self.dec_mode: str = list(dec_conf.keys())[0]
145
+ assert self.dec_mode in ["naive", "pup", "mla"]
146
+ self.dec_param: dict = dec_conf[self.dec_mode]
147
+ output_size = int(config.output_size)
148
+
149
+ # patch embedding
150
+ self.patch_embed = PatchEmbed(
151
+ embed_dim=embed_dim,
152
+ kernel_size=patch_size,
153
+ bias=patch_bias,
154
+ )
155
+
156
+ # positional embedding
157
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
158
+ self.pos_dropout = nn.Dropout(p=dropout)
159
+
160
+ # transformer encoder
161
+ self.layers = nn.ModuleList()
162
+ for _ in range(num_layers):
163
+ self.layers.append(
164
+ TransformerBlock(
165
+ dim=embed_dim,
166
+ num_attn_heads=num_attn_heads,
167
+ attn_head_dim=attn_head_dim,
168
+ mlp_dim=mlp_dim,
169
+ attn_dropout=dropout,
170
+ ffn_dropout=dropout,
171
+ )
172
+ )
173
+
174
+ # decoder
175
+ self.dec_layers = nn.ModuleList()
176
+ if self.dec_mode == "naive":
177
+ self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
178
+ dec_out_channel = embed_dim
179
+ elif self.dec_mode == "pup":
180
+ self.dec_layers.append(nn.LayerNorm(embed_dim))
181
+ dec_up_scale = int(self.dec_param["up_scale"])
182
+ available_up_count = int(
183
+ math.log(data_len // num_patches, dec_up_scale)
184
+ )  # number of steps that can upsample by scale_factor; the remainder is upsampled by size
185
+ pup_channels = int(self.dec_param["channels"])
186
+ dec_in_channel = embed_dim
187
+ dec_out_channel = pup_channels
188
+ dec_kernel_size = int(self.dec_param["kernel_size"])
189
+ dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
190
+ assert dec_kernel_size in [1, 3]  # same as the original code
191
+ for i in range(available_up_count + 1):
192
+ for _ in range(dec_num_convs_by_layer):
193
+ self.dec_layers.append(
194
+ nn.Conv1d(
195
+ dec_in_channel,
196
+ dec_out_channel,
197
+ kernel_size=dec_kernel_size,
198
+ stride=1,
199
+ padding=(dec_kernel_size - 1) // 2,
200
+ )
201
+ )
202
+ dec_in_channel = dec_out_channel
203
+ if i < available_up_count:
204
+ self.dec_layers.append(
205
+ nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
206
+ )
207
+ else: # last upsample
208
+ self.dec_layers.append(
209
+ nn.Upsample(size=data_len, mode=interpolate_mode)
210
+ )
211
+ else: # mla
212
+ dec_up_scale = int(self.dec_param["up_scale"])
213
+ assert (
214
+ data_len >= dec_up_scale * num_patches
215
+ )  # the intermediate transformer outputs, upsampled by up_scale, must still be shorter than the original for the final upsample to be meaningful
216
+ dec_output_step = int(self.dec_param["output_step"])
217
+ assert num_layers % dec_output_step == 0
218
+ dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
219
+ dec_kernel_size = int(self.dec_param["kernel_size"])
220
+ mid_feature_cnt = num_layers // dec_output_step
221
+ mla_channel = int(self.dec_param["channels"])
222
+ for _ in range(mid_feature_cnt):
223
+ # conv-upsample applied to the feature maps taken every output_step from the intermediate transformer blocks
224
+ dec_in_channel = embed_dim
225
+ dec_layers_each_upsample = []
226
+ for _ in range(dec_num_convs_by_layer):
227
+ dec_layers_each_upsample.append(
228
+ nn.Conv1d(
229
+ dec_in_channel,
230
+ mla_channel,
231
+ kernel_size=dec_kernel_size,
232
+ stride=1,
233
+ padding=(dec_kernel_size - 1) // 2,
234
+ )
235
+ )
236
+ dec_in_channel = mla_channel
237
+ dec_layers_each_upsample.append(
238
+ nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
239
+ )
240
+ self.dec_layers.append(nn.Sequential(*dec_layers_each_upsample))
241
+ # last decoder layer: upsample after concatenating the intermediate feature maps
242
+ self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
243
+
244
+ dec_out_channel = (
245
+ mla_channel * mid_feature_cnt
246
+ )  # the feature maps produced by applying self.dec_layers to the intermediate transformer outputs are concatenated channel-wise (mid_feature_cnt of them), so this increased channel count must be used as the in_channel of self.cls below
247
+
248
+ self.cls = nn.Conv1d(dec_out_channel, output_size, 1, bias=False)
249
+
250
+ def forward(self, input: torch.Tensor, y=None):
251
+ output = input
252
+
253
+ # patch embedding
254
+ output = self.patch_embed(output)
255
+
256
+ # positional embedding
257
+ output += self.pos_embed
258
+ output = self.pos_dropout(output)
259
+
260
+ outputs = []
261
+ # transformer encoder
262
+ for i, layer in enumerate(self.layers):
263
+ output = layer(output)
264
+ if self.dec_mode == "mla":
265
+ if (i + 1) % int(self.dec_param["output_step"]) == 0:
266
+ outputs.append(output.transpose(1, 2))
267
+ if self.dec_mode != "mla": # mla 의 경우 위에서 이미 추가
268
+ outputs.append(output.transpose(1, 2))
269
+
270
+ # decoder
271
+ if self.dec_mode == "naive":
272
+ assert len(outputs) == 1
273
+ output = outputs[0]
274
+ output = self.dec_layers[0](output)
275
+ elif self.dec_mode == "pup":
276
+ assert len(outputs) == 1
277
+ output = outputs[0]
278
+ pup_norm = self.dec_layers[0]
279
+ output = pup_norm(output.transpose(1, 2)).transpose(1, 2)
280
+ for i, dec_layer in enumerate(self.dec_layers[1:]):
281
+ output = dec_layer(output)
282
+ else: # mla
283
+ dec_output_step = int(self.dec_param["output_step"])
284
+ mid_feature_cnt = len(self.layers) // dec_output_step
285
+ assert len(outputs) == mid_feature_cnt
286
+ for i in range(len(outputs)):
287
+ outputs[i] = self.dec_layers[i](outputs[i])
288
+ output = torch.cat(outputs, dim=1)
289
+ output = self.dec_layers[-1](output)
290
+
291
+ return self.cls(output)
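A small usage sketch for the SETR above using the simplest ("naive") decoder; the config fields mirror __init__ and the sizes are illustrative assumptions:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        embed_dim=256, data_len=2048, patch_size=16, patch_bias=False,
        dropout=0.1, num_layers=4, num_attn_heads=8, attn_head_dim=32, mlp_dim=512,
        interpolate_mode="linear", dec_conf={"naive": {}}, output_size=3,
    )
    model = SETR(cfg).eval()
    x = torch.randn(2, 1, cfg.data_len)   # length must be a multiple of patch_size
    with torch.no_grad():
        y = model(x)                      # (2, 3, 2048)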
res/impl/SegFormer.py ADDED
@@ -0,0 +1,294 @@
1
+ """
2
+ paper: https://arxiv.org/abs/2105.15203
3
+ - ref:
4
+ - encoder:
5
+ - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py
6
+ - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/mit.py
7
+ - decoder:
8
+ - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/decode_heads/segformer_head.py
9
+ - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py
10
+ """
11
+
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.functional import F
16
+ import math
17
+ from einops import rearrange
18
+
19
+
20
+ class MixFFN(nn.Module):
21
+ def __init__(self, embed_dim, channels, dropout=0.0):
22
+ super().__init__()
23
+
24
+ self.layers = nn.Sequential(
25
+ nn.Conv1d( # fc1
26
+ in_channels=embed_dim, out_channels=channels, kernel_size=1, stride=1
27
+ ),
28
+ nn.Conv1d( # position embed (depthwise-separable)
29
+ in_channels=channels,
30
+ out_channels=channels,
31
+ kernel_size=3,
32
+ stride=1,
33
+ padding=1,
34
+ groups=channels,
35
+ ),
36
+ nn.GELU(),
37
+ nn.Dropout(dropout),
38
+ nn.Conv1d( # fc2
39
+ in_channels=channels, out_channels=embed_dim, kernel_size=1
40
+ ),
41
+ nn.Dropout(dropout),
42
+ )
43
+
44
+ def forward(self, x):
45
+ out = x.transpose(1, 2)
46
+ out = self.layers(out)
47
+ out = out.transpose(1, 2)
48
+ return out
49
+
50
+
51
+ class EfficientMultiheadAttention(nn.Module):
52
+ """
53
+ Borrows the Spatial-Reduction Attention used in PVT (Pyramid Vision Transformer).
54
+ In variable names, sr is short for Spatial-Reduction.
55
+ """
56
+
57
+ def __init__(
58
+ self, embed_dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1
59
+ ):
60
+ super().__init__()
61
+
62
+ assert (
63
+ embed_dim % num_heads == 0
64
+ ), f"dim {embed_dim} should be divided by num_heads {num_heads}."
65
+
66
+ self.num_heads = num_heads
67
+ head_dim = embed_dim // num_heads
68
+ self.scale = head_dim**-0.5
69
+
70
+ self.q = nn.Linear(embed_dim, embed_dim, bias=False)
71
+ self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=False)
72
+ self.attn_drop = nn.Dropout(attn_drop)
73
+ self.proj = nn.Linear(embed_dim, embed_dim)
74
+ self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+ self.sr_ratio = sr_ratio
77
+ if sr_ratio > 1:
78
+ self.sr = nn.Conv1d(
79
+ embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
80
+ )
81
+ self.norm = nn.LayerNorm(embed_dim)
82
+
83
+ def forward(self, x):
84
+ B, N, C = x.shape
85
+ q = self.q(x)
86
+ q = rearrange(q, "b n (h c) -> b h n c", h=self.num_heads)
87
+
88
+ if self.sr_ratio > 1:
89
+ x_ = x.transpose(1, 2)
90
+ x_ = self.sr(x_).transpose(1, 2)
91
+ x_ = self.norm(x_)
92
+ kv = self.kv(x_)
93
+ kv = rearrange(
94
+ kv,
95
+ "b n (two_heads h c) -> two_heads b h n c",
96
+ two_heads=2,
97
+ h=self.num_heads,
98
+ )
99
+ else:
100
+ kv = self.kv(x)
101
+ kv = rearrange(
102
+ kv,
103
+ "b n (two_heads h c) -> two_heads b h n c",
104
+ two_heads=2,
105
+ h=self.num_heads,
106
+ )
107
+ k, v = kv[0], kv[1]
108
+
109
+ attn = (q @ k.transpose(-2, -1)) * self.scale
110
+ attn = attn.softmax(dim=-1)
111
+ attn = self.attn_drop(attn)
112
+
113
+ x = (attn @ v).transpose(1, 2)
114
+ x = x.reshape(B, N, C)
115
+ x = self.proj(x)
116
+ x = self.proj_drop(x)
117
+
118
+ return x
119
+
120
+
121
+ class TransformerBlock(nn.Module):
122
+ def __init__(self, embed_dim, num_heads, ffn_channels, dropout=0.2, sr_ratio=1):
123
+ super().__init__()
124
+
125
+ self.attn = nn.Sequential(
126
+ nn.LayerNorm(embed_dim),
127
+ EfficientMultiheadAttention(
128
+ embed_dim=embed_dim,
129
+ num_heads=num_heads,
130
+ attn_drop=dropout,
131
+ proj_drop=dropout,
132
+ sr_ratio=sr_ratio,
133
+ ),
134
+ )
135
+
136
+ self.ffn = nn.Sequential(
137
+ nn.LayerNorm(embed_dim),
138
+ MixFFN(embed_dim=embed_dim, channels=ffn_channels, dropout=dropout),
139
+ )
140
+
141
+ def forward(self, x):
142
+ x = x + self.attn(x)
143
+ x = x + self.ffn(x)
144
+ return x
145
+
146
+
147
+ class PatchEmbed(nn.Module):
148
+ def __init__(
149
+ self,
150
+ in_channels=1,
151
+ embed_dim=1024,
152
+ kernel_size=7,
153
+ stride=4,
154
+ padding=3,
155
+ bias=False,
156
+ ):
157
+ super().__init__()
158
+
159
+ self.projection = nn.Conv1d(
160
+ in_channels=in_channels,
161
+ out_channels=embed_dim,
162
+ kernel_size=kernel_size,
163
+ stride=stride,
164
+ padding=padding,
165
+ bias=bias,
166
+ )
167
+
168
+ def forward(self, x: torch.Tensor):
169
+ return self.projection(x).transpose(1, 2)
170
+
171
+
172
+ class MiT(nn.Module):
173
+ """MixVisionTransformer"""
174
+
175
+ def __init__(
176
+ self,
177
+ embed_dim=512,
178
+ num_blocks=[2, 2, 6, 2],
179
+ num_heads=[1, 2, "ceil"],
180
+ sr_ratios=[1, 2, "ceil"],
181
+ mlp_ratio=4,
182
+ dropout=0.2,
183
+ ):
184
+ super().__init__()
185
+
186
+ num_stages = len(num_blocks)
187
+ round_func = getattr(math, num_heads[2])  # math.ceil or math.floor
188
+ num_heads = [
189
+ round_func((num_heads[0] * math.pow(num_heads[1], itr)))
190
+ for itr in range(num_stages)
191
+ ]
192
+ round_func = getattr(math, sr_ratios[2])  # math.ceil or math.floor
193
+ sr_ratios = [
194
+ round_func(sr_ratios[0] * math.pow(sr_ratios[1], itr))
195
+ for itr in range(num_stages)
196
+ ]
197
+ sr_ratios.reverse()
198
+
199
+ self.embed_dims = [embed_dim * num_head for num_head in num_heads]
200
+ patch_kernel_sizes = [7] # [7, 3, 3, ..]
201
+ patch_kernel_sizes.extend([3] * (num_stages - 1))
202
+ patch_strides = [4] # [4, 2, 2, ..]
203
+ patch_strides.extend([2] * (num_stages - 1))
204
+ patch_paddings = [3] # [3, 1, 1, ..]
205
+ patch_paddings.extend([1] * (num_stages - 1))
206
+
207
+ in_channels = 1
208
+ self.stages = nn.ModuleList()
209
+ for i, num_block in enumerate(num_blocks):
210
+ patch_embed = PatchEmbed(
211
+ in_channels=in_channels,
212
+ embed_dim=self.embed_dims[i],
213
+ kernel_size=patch_kernel_sizes[i],
214
+ stride=patch_strides[i],
215
+ padding=patch_paddings[i],
216
+ )
217
+ blocks = nn.ModuleList(
218
+ [
219
+ TransformerBlock(
220
+ embed_dim=self.embed_dims[i],
221
+ num_heads=num_heads[i],
222
+ ffn_channels=mlp_ratio * self.embed_dims[i],
223
+ dropout=dropout,
224
+ sr_ratio=sr_ratios[i],
225
+ )
226
+ for _ in range(num_block)
227
+ ]
228
+ )
229
+ in_channels = self.embed_dims[i]
230
+ norm = nn.LayerNorm(self.embed_dims[i])
231
+ self.stages.append(nn.ModuleList([patch_embed, blocks, norm]))
232
+
233
+ def forward(self, x):
234
+ outs = []
235
+
236
+ for stage in self.stages:
237
+ x = stage[0](x) # patch embed
238
+ for block in stage[1]: # transformer blocks
239
+ x = block(x)
240
+ x = stage[2](x) # norm
241
+ x = x.transpose(1, 2)
242
+ outs.append(x)
243
+
244
+ return outs
245
+
246
+
247
+ class SegFormer(nn.Module):
248
+ def __init__(self, config):
249
+ super().__init__()
250
+
251
+ embed_dim = int(config.embed_dim)
252
+ num_blocks = config.num_blocks
253
+ num_heads = config.num_heads
254
+ assert len(num_heads) == 3 and num_heads[2] in ["floor", "ceil"]
255
+ sr_ratios = config.sr_ratios
256
+ assert len(sr_ratios) == 3 and sr_ratios[2] in ["floor", "ceil"]
257
+ mlp_ratio = int(config.mlp_ratio)
258
+ dropout = float(config.dropout)
259
+ decoder_channels = int(config.decoder_channels)
260
+ self.interpolate_mode = str(config.interpolate_mode)
261
+ output_size = int(config.output_size)
262
+
263
+ self.MiT = MiT(embed_dim, num_blocks, num_heads, sr_ratios, mlp_ratio, dropout)
264
+
265
+ num_stages = len(num_blocks)
266
+ self.decode_mlps = nn.ModuleList(
267
+ [
268
+ nn.Conv1d(self.MiT.embed_dims[i], decoder_channels, 1, bias=False)
269
+ for i in range(num_stages)
270
+ ]
271
+ )
272
+ self.decode_fusion = nn.Conv1d(
273
+ decoder_channels * num_stages, decoder_channels, 1, bias=False
274
+ )
275
+
276
+ self.cls = nn.Conv1d(decoder_channels, output_size, 1, bias=False)
277
+
278
+ def forward(self, input: torch.Tensor, y=None):
279
+ output = input
280
+
281
+ output = self.MiT(output)
282
+ for i, (_output, decode_mlp) in enumerate(zip(output, self.decode_mlps)):
283
+ _output = decode_mlp(_output)
284
+ if i != 0:
285
+ _output = F.interpolate(
286
+ _output, size=output[0].shape[2], mode=self.interpolate_mode
287
+ )
288
+ output[i] = _output
289
+
290
+ output = torch.concat(output, dim=1)
291
+ output = self.decode_fusion(output)
292
+ output = self.cls(output)
293
+
294
+ return F.interpolate(output, size=input.shape[2], mode=self.interpolate_mode)
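An illustrative SegFormer instantiation (not part of the commit); with these assumed values the MiT stages get heads [1, 2, 4, 8], embed dims [32, 64, 128, 256], and sr_ratios [8, 4, 2, 1]:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        embed_dim=32, num_blocks=[2, 2, 2, 2],
        num_heads=[1, 2, "ceil"], sr_ratios=[1, 2, "ceil"],
        mlp_ratio=4, dropout=0.1, decoder_channels=256,
        interpolate_mode="linear", output_size=3,
    )
    model = SegFormer(cfg).eval()
    x = torch.randn(2, 1, 2048)
    with torch.no_grad():
        y = model(x)              # (2, 3, 2048)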
res/impl/UNet3PlusDeepSup.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ paper: https://arxiv.org/abs/2004.08790
3
+ ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py
4
+ """
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.functional import F
9
+
10
+
11
+ class UNetConv(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_size,
15
+ out_size,
16
+ is_batchnorm=True,
17
+ num_layers=2,
18
+ kernel_size=3,
19
+ stride=1,
20
+ padding=1,
21
+ ):
22
+ super().__init__()
23
+ self.num_layers = num_layers
24
+
25
+ for i in range(num_layers):
26
+ seq = [nn.Conv1d(in_size, out_size, kernel_size, stride, padding)]
27
+ if is_batchnorm:
28
+ seq.append(nn.BatchNorm1d(out_size))
29
+ seq.append(nn.ReLU())
30
+ conv = nn.Sequential(*seq)
31
+ setattr(self, "conv%d" % i, conv)
32
+ in_size = out_size
33
+
34
+ def forward(self, inputs):
35
+ x = inputs
36
+ for i in range(self.num_layers):
37
+ conv = getattr(self, "conv%d" % i)
38
+ x = conv(x)
39
+
40
+ return x
41
+
42
+
43
+ class UNet3PlusDeepSup(nn.Module):
44
+ def __init__(self, config):
45
+ super().__init__()
46
+
47
+ self.config = config
48
+ inplanes = int(config.inplanes)
49
+ kernel_size = int(config.kernel_size)
50
+ padding = (kernel_size - 1) // 2
51
+ num_encoder_layers = int(config.num_encoder_layers)
52
+ encoder_batchnorm = bool(config.encoder_batchnorm)
53
+ self.num_depths = int(config.num_depths)
54
+ self.interpolate_mode = str(config.interpolate_mode)
55
+ dropout = float(config.dropout)
56
+ self.use_cgm = bool(config.use_cgm)
57
+ # sum_of_sup == True: elementwise-sum all sup outputs into one dense map and compute the loss against the label
58
+ # sum_of_sup == False: compute a loss between each sup output and the label separately and collect them into one loss
59
+ self.sum_of_sup = bool(config.sum_of_sup)
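+ # e.g. (illustrative): sum_of_sup=True  -> loss_fn(model(x), y) on the single summed map;
+ #                       sum_of_sup=False -> sum(loss_fn(o, y) for o in model(x)), one term per sup output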
60
+ # set in TrialSetup._init_network_params
61
+ self.output_size: int = config.output_size
62
+
63
+ # Encoder
64
+ self.encoders = torch.nn.ModuleList()
65
+ for i in range(self.num_depths):
66
+ """One depth corresponds to (MaxPool - UNetConv); as the exception, the first depth's encode result comes from (UNetConv) alone (channel widths are illustrated after this loop)"""
67
+ _encoders = []
68
+ if i != 0:
69
+ _encoders.append(nn.MaxPool1d(2))
70
+ _encoders.append(
71
+ UNetConv(
72
+ 1 if i == 0 else (inplanes * (2 ** (i - 1))),
73
+ inplanes * (2**i),
74
+ is_batchnorm=encoder_batchnorm,
75
+ num_layers=num_encoder_layers,
76
+ kernel_size=kernel_size,
77
+ stride=1,
78
+ padding=padding,
79
+ )
80
+ )
81
+ self.encoders.append(nn.Sequential(*_encoders))
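+ # e.g. (illustrative) inplanes=16, num_depths=5: encoder channel widths are 16, 32, 64, 128, 256,
+ # and every depth after the first halves the sequence length with MaxPool1d(2)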
82
+
83
+ # CGM: Classification-Guided Module
84
+ if self.use_cgm:
85
+ self.cls = nn.Sequential(
86
+ nn.Dropout(dropout),
87
+ nn.Conv1d(
88
+ inplanes * (2 ** (self.num_depths - 1)), 2 * self.output_size, 1
89
+ ),
90
+ nn.AdaptiveMaxPool1d(1),
91
+ nn.Sigmoid(),
92
+ )
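+ # yields 2 * output_size sigmoid scores per sample, later viewed as (B, output_size, 2)
+ # and argmax'd in forward() into a 0/1 "is this class present anywhere" flag per class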
93
+
94
+ # Decoder
95
+ self.up_channels = inplanes * self.num_depths
96
+
97
+ self.decoders = torch.nn.ModuleList()
98
+ for i in reversed(range(self.num_depths - 1)):
99
+ """
100
+ Each decoder takes every encode result either MaxPooled, as-is (Conv, BatchNorm, ReLU only), or Upsampled, then concatenates the branches and applies (Conv, BatchNorm, ReLU).
101
+ Upsampling, however, is done with torch.functional.interpolate() in the forward pass so the size can easily be matched to the encode result. (The branch roles are illustrated after this loop.)
102
+ """
103
+ # each stage's decoder always has num_depths branches; the branches differ only in whether they apply MaxPool, identity, or Upsample internally
104
+ _decoders = torch.nn.ModuleList()
105
+ for j in range(self.num_depths):
106
+ _each_decoders = []
107
+ if j < i:
108
+ _each_decoders.append(nn.MaxPool1d(2 ** (i - j), ceil_mode=True))
109
+ if i < j < self.num_depths - 1:
110
+ _each_decoders.append(
111
+ nn.Conv1d(
112
+ inplanes * self.num_depths,
113
+ inplanes,
114
+ kernel_size,
115
+ padding=padding,
116
+ )
117
+ )
118
+ else:
119
+ _each_decoders.append(
120
+ nn.Conv1d(
121
+ inplanes * (2**j), inplanes, kernel_size, padding=padding
122
+ )
123
+ )
124
+ _each_decoders.append(nn.BatchNorm1d(inplanes))
125
+ _each_decoders.append(nn.ReLU())
126
+ _decoders.append(nn.Sequential(*_each_decoders))
127
+ _decoders.append(
128
+ nn.Sequential(
129
+ nn.Conv1d(
130
+ self.up_channels, self.up_channels, kernel_size, padding=padding
131
+ ),
132
+ nn.BatchNorm1d(self.up_channels),
133
+ nn.ReLU(),
134
+ )
135
+ )
136
+ self.decoders.append(_decoders)
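+ # e.g. (illustrative) num_depths=5, decoder for X4De: branches j=0..2 MaxPool X1Ee..X3Ee down to
+ # X4Ee's length, branch j=3 takes X4Ee as-is, branch j=4 upsamples X5Ee (in forward); every branch
+ # is reduced to `inplanes` channels, concatenated to up_channels, then fused by the final Sequential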
137
+
138
+ # the earlier convs take up_channels (inplanes * num_depths; 320 in the reference code) input channels; the last conv matches the output channels of the final encoder result
139
+ self.sup_conv = torch.nn.ModuleList()
140
+ for i in range(self.num_depths - 1):
141
+ self.sup_conv.append(
142
+ nn.Sequential(
143
+ nn.Conv1d(
144
+ self.up_channels, self.output_size, kernel_size, padding=padding
145
+ ),
146
+ nn.BatchNorm1d(self.output_size),
147
+ nn.ReLU(),
148
+ )
149
+ )
150
+ self.sup_conv.append(
151
+ nn.Sequential(
152
+ nn.Conv1d(
153
+ inplanes * (2 ** (self.num_depths - 1)),
154
+ self.output_size,
155
+ kernel_size,
156
+ padding=padding,
157
+ ),
158
+ nn.BatchNorm1d(self.output_size),
159
+ nn.ReLU(),
160
+ )
161
+ )
162
+
163
+ def forward(self, input: torch.Tensor, y=None):
164
+ # Encoder
165
+ output = input
166
+ enc_features = [] # X1Ee, X2Ee, .. , X5Ee
167
+ dec_features = [] # X5Ee, X4De, .. , X1De
168
+ for encoder in self.encoders:
169
+ output = encoder(output)
170
+ enc_features.append(output)
171
+ dec_features.append(output)  # outside the loop: only the final encoder output (X5Ee) seeds dec_features
172
+
173
+ # CGM
174
+ cls_branch_max = None
175
+ if self.use_cgm:
176
+ # (B, 2*3(output_size), 1)
177
+ cls_branch: torch.Tensor = self.cls(enc_features[-1])
178
+ # (B, 3(output_size))
179
+ cls_branch_max = cls_branch.view(
180
+ input.shape[0], self.output_size, 2
181
+ ).argmax(2)
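+ # cls_branch_max[b, c] is 0 or 1 (class c predicted absent/present in sample b);
+ # it gates channel c of every sup map through the einsum below, zeroing absent classes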
182
+
183
+ # Decoder
184
+ for i in reversed(range(self.num_depths - 1)):
185
+ _each_dec_feature = []
186
+ for j in range(self.num_depths):
187
+ if j <= i:
188
+ _each_enc = enc_features[j]
189
+ else:
190
+ _each_enc = F.interpolate(
191
+ dec_features[self.num_depths - j - 1],
192
+ enc_features[i].shape[2],
193
+ mode=self.interpolate_mode,
194
+ )
195
+ _each_dec_feature.append(
196
+ self.decoders[self.num_depths - i - 2][j](_each_enc)
197
+ )
198
+ dec_features.append(
199
+ self.decoders[self.num_depths - i - 2][-1](
200
+ torch.cat(_each_dec_feature, dim=1)
201
+ )
202
+ )
203
+
204
+ sup = []
205
+ for i, (dec_feature, sup_conv) in enumerate(
206
+ zip(dec_features, reversed(self.sup_conv))
207
+ ):
208
+ if i < self.num_depths - 1:
209
+ sup.append(
210
+ F.interpolate(
211
+ sup_conv(dec_feature),
212
+ input.shape[2],
213
+ mode=self.interpolate_mode,
214
+ )
215
+ )
216
+ else:
217
+ sup.append(sup_conv(dec_feature))
218
+
219
+ if self.use_cgm:
220
+ if self.sum_of_sup:
221
+ return torch.sigmoid(
222
+ sum(
223
+ [
224
+ torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
225
+ for _sup in reversed(sup)
226
+ ]
227
+ )
228
+ )
229
+ else:
230
+ return [
231
+ torch.sigmoid(
232
+ torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
233
+ )
234
+ for _sup in reversed(sup)
235
+ ]
236
+
237
+ else:
238
+ if self.sum_of_sup:
239
+ return torch.sigmoid(sum(sup))
240
+ else:
241
+ return [torch.sigmoid(_sup) for _sup in reversed(sup)]
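A hedged end-to-end sketch of how the two deep-supervision output modes might be consumed (not part of this commit; the config values, input shape, and BCE loss are illustrative assumptions, and the classes above are taken as defined, with dec_features seeded only by the final encoder output as noted in forward()):

import torch
from torch import nn
from types import SimpleNamespace

cfg = SimpleNamespace(  # hypothetical values for every field read in __init__
    inplanes=16, kernel_size=3, num_encoder_layers=2, encoder_batchnorm=True,
    num_depths=5, interpolate_mode="linear", dropout=0.1,
    use_cgm=False, sum_of_sup=False, output_size=3,
)
model = UNet3PlusDeepSup(cfg)
x = torch.randn(4, 1, 5000)                                   # (batch, channels, samples)
y = torch.randint(0, 2, (4, cfg.output_size, 5000)).float()   # dense 0/1 labels
bce = nn.BCELoss()                                            # outputs are already sigmoid-activated

out = model(x)
if isinstance(out, list):          # sum_of_sup == False: one map per supervision branch
    loss = sum(bce(o, y) for o in out) / len(out)
else:                              # sum_of_sup == True: a single summed map
    loss = bce(out, y)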
res/models/hrnetv2/best_config.json DELETED
@@ -1,151 +0,0 @@
1
- {
2
- "train": {
3
- "progress": true,
4
- "random_seed": 2407041220,
5
- "resume_dir": [],
6
- "checkpoint_dir": "/bfai/nfs_export/workspace/share/result/wogh/hrnet/train-240704_123013",
7
- "checkpoint_save_freq": 1,
8
- "working_dir": "",
9
- "user": "wogh",
10
- "name": "hrnet",
11
- "exp_name": "wogh:hrnet",
12
- "type": "supervised",
13
- "task": "segmentation",
14
- "epochs": 501,
15
- "batch_size": 64,
16
- "hpo": {
17
- "num_samples": 256,
18
- "criteria": {
19
- "jaccard_avg": 1
20
- },
21
- "scheduler": {
22
- "ASHAScheduler": {
23
- "grace_period": 200,
24
- "max_t": 501
25
- }
26
- }
27
- },
28
- "label": {
29
- "num_labels": 3,
30
- "path": [
31
- "/bfai/nfs_export/workspace/share/labels/pqrst/ludb/train.csv",
32
- "/bfai/nfs_export/workspace/share/labels/pqrst/ludb/valid.csv",
33
- "/bfai/nfs_export/workspace/share/labels/pqrst/ludb/test.csv"
34
- ],
35
- "target": [
36
- "p_onoffs",
37
- "qrs_onoffs",
38
- "t_onoffs"
39
- ],
40
- "split_ratio": [
41
- 1,
42
- 1,
43
- 1
44
- ]
45
- },
46
- "resource_per_trial": {
47
- "num_workers": 1,
48
- "num_gpus_per_worker": 1,
49
- "num_cpus_per_worker": 16
50
- },
51
- "comment": "",
52
- "tracking": true,
53
- "available_resources": {
54
- "available_gpus": 16.0
55
- }
56
- },
57
- "solver": {
58
- "SolverPQRST": {
59
- "mixed_precision": true,
60
- "gradient_clip": 0.1
61
- }
62
- },
63
- "datasets": [
64
- {
65
- "ECGPQRSTDataset": {
66
- "lead_type": [
67
- "I",
68
- "II",
69
- "III",
70
- "aVR",
71
- "aVL",
72
- "aVF",
73
- "V1",
74
- "V2",
75
- "V3",
76
- "V4",
77
- "V5",
78
- "V6"
79
- ],
80
- "aux_data": [],
81
- "normalization": "z_norm",
82
- "second": 10,
83
- "hz": 500
84
- }
85
- }
86
- ],
87
- "models": [
88
- {
89
- "network": {
90
- "HRNetV2": {
91
- "data_len": 5000,
92
- "kernel_size": 5,
93
- "dilation": 1,
94
- "num_stages": 3,
95
- "num_blocks": 6,
96
- "num_modules": [
97
- 1,
98
- 1,
99
- 1,
100
- 4,
101
- 3
102
- ],
103
- "use_bottleneck": [
104
- 1,
105
- 0,
106
- 0,
107
- 0,
108
- 0
109
- ],
110
- "stage1_channels": 128,
111
- "num_channels_init": 48,
112
- "interpolate_mode": "linear",
113
- "task": "segmentation",
114
- "num_leads": 12,
115
- "num_aux": 0,
116
- "output_size": 3,
117
- "aux_output_size": 0
118
- }
119
- },
120
- "optimizer": [
121
- {
122
- "SGD": {
123
- "lr": 0.0983058839402403,
124
- "momentum": 0.9,
125
- "weight_decay": 0.0003850652731758502,
126
- "sharpness_min": false
127
- }
128
- }
129
- ],
130
- "scheduler": [
131
- {
132
- "PolynomialLR": {
133
- "total_iters": 501,
134
- "power": 0.0
135
- }
136
- }
137
- ]
138
- }
139
- ],
140
- "loss_fns": [
141
- {
142
- "BCEWithLogitsLoss": {}
143
- }
144
- ],
145
- "cur_epoch": 358,
146
- "cutoff": [
147
- 0.001163482666015625,
148
- 0.15087890625,
149
- -0.587890625
150
- ]
151
- }