File size: 5,031 Bytes
aefacda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
paper: https://arxiv.org/abs/1605.06211
ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
"""

import torch
import torch.nn as nn


class FCN(nn.Module):
    """1-D fully convolutional network for dense sequence labeling.

    paper: https://arxiv.org/abs/1605.06211
    ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py

    An encoder of conv/BN/ReLU blocks halves the temporal length after each
    block via max-pooling.  Selected pool outputs are projected to class
    scores with 1x1 convs and fused with upsampled deeper scores (FCN-8s
    style skip connections); the fused map is upsampled past the input
    length with a strided transposed conv and center-cropped back to the
    exact input length.
    """

    def __init__(self, config):
        """Build the network.

        Expected ``config`` attributes (all read eagerly here):
            kernel_size: kernel size of the encoder convolutions.
            last_layer_kernel_size: kernel size of the first conv in the
                4096-channel head (applied without padding).
            inplanes: channel width of the first encoder block; doubles
                with every subsequent block.
            combine_conf: dict with ``"num_layers"`` (must be 4, 5 or 6)
                and ``"combine_until"`` — index of the first pool output
                that is fused back via a skip connection.
            num_convs: conv/BN/ReLU repeats per encoder block.
            dilation: dilation of the encoder convolutions.
            dropout: dropout probability inside the 4096-channel head.
            output_size: number of output classes, e.g. 3 (p, qrs, t).

        Raises:
            ValueError: if ``combine_until`` is not smaller than
                ``num_layers``.
            KeyError: if ``num_layers`` is not 4, 5 or 6.
        """
        super().__init__()

        self.config = config
        self.kernel_size = int(config.kernel_size)
        last_layer_kernel_size = int(config.last_layer_kernel_size)
        inplanes = int(config.inplanes)
        combine_conf: dict = config.combine_conf
        self.num_layers = int(combine_conf["num_layers"])
        # Over-pad the very first conv (cf. the pad=100 trick in the original
        # FCN) so the final upsampled map is wider than the input and can be
        # center-cropped back to the exact input length; the amount needed
        # grows with pooling depth.
        self.first_padding = {6: 240, 5: 130, 4: 80}[self.num_layers]
        self.num_convs = int(config.num_convs)
        self.dilation = int(config.dilation)
        self.combine_until = int(combine_conf["combine_until"])
        # Validate with an explicit raise: `assert` is stripped under -O.
        if self.combine_until >= self.num_layers:
            raise ValueError(
                "combine_until must be < num_layers "
                f"(got {self.combine_until} >= {self.num_layers})"
            )
        dropout = float(config.dropout)
        output_size = config.output_size  # e.g. 3 (p, qrs, t)

        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                self._make_layer(
                    1 if i == 0 else inplanes * (2 ** (i - 1)),
                    inplanes * (2 ** i),
                    is_first=(i == 0),
                )
            )
        # Final conv block without a pooling stage; unlike the encoder blocks
        # its conv count (2) and channel width (4096) are fixed, and it
        # applies dropout.  Compute its input width explicitly instead of
        # reusing the loop variable leaked from the loop above.
        last_encoder_width = inplanes * (2 ** (self.num_layers - 1))
        self.layers.append(
            nn.Sequential(
                nn.Conv1d(last_encoder_width, 4096, last_layer_kernel_size),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Conv1d(4096, 4096, 1),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        )
        score_convs = []
        up_convs = []
        for i in range(self.combine_until, self.num_layers - 1):
            # One score/up-conv pair per pool output that gets fused back in.
            score_convs.append(
                nn.Conv1d(inplanes * (2 ** i), output_size, kernel_size=1, bias=False)
            )
            up_convs.append(
                nn.ConvTranspose1d(output_size, output_size, kernel_size=4, stride=2)
            )
        # Score conv applied to the (un-pooled) 4096-channel head output;
        # score_convs therefore always holds one more module than up_convs.
        score_convs.append(nn.Conv1d(4096, output_size, kernel_size=1, bias=False))

        # Reverse so index 0 corresponds to the deepest (head) features,
        # matching the deep-to-shallow fusion order in forward().
        score_convs.reverse()
        self.score_convs = nn.ModuleList(score_convs)
        self.up_convs = nn.ModuleList(up_convs)
        # Final upsampling back past the input resolution; the stride undoes
        # the 2**(combine_until + 1) pooling factor that remains after the
        # skip fusions.
        self.last_up_convs = nn.ConvTranspose1d(
            output_size,
            output_size,
            kernel_size=2 ** (self.combine_until + 1) * 2,  # stride * 2
            stride=2 ** (self.combine_until + 1),
        )

    def _make_layer(
        self,
        in_channel: int,
        out_channel: int,
        is_first: bool = False,
    ) -> nn.Sequential:
        """Return one encoder block: ``num_convs`` x (conv, BN, ReLU) + pool.

        The very first conv of the network gets the oversized
        ``first_padding``; every other conv uses "same" padding for its
        kernel size and dilation.  The trailing max-pool halves the length
        (``ceil_mode=True`` so odd lengths round up).
        """
        layer = []
        plane = in_channel
        for idx in range(self.num_convs):
            if idx == 0 and is_first:
                padding = self.first_padding
            else:
                padding = (self.dilation * (self.kernel_size - 1)) // 2
            layer.append(
                nn.Conv1d(
                    plane,
                    out_channel,
                    kernel_size=self.kernel_size,
                    padding=padding,
                    dilation=self.dilation,
                    bias=False,
                )
            )
            layer.append(nn.BatchNorm1d(out_channel))
            layer.append(nn.ReLU())
            plane = out_channel

        layer.append(nn.MaxPool1d(2, 2, ceil_mode=True))
        return nn.Sequential(*layer)

    @staticmethod
    def _center_crop(x: torch.Tensor, length: int) -> torch.Tensor:
        """Crop ``x`` along dim 2 to ``length``, centered."""
        offset = (x.shape[2] - length) // 2
        return x[:, :, offset : offset + length]

    def forward(self, input: torch.Tensor, y=None):
        """Run the network.

        Args:
            input: tensor of shape ``(batch, channels, length)``; the first
                encoder block expects a single input channel.
            y: unused; kept for interface compatibility.

        Returns:
            Class scores of shape ``(batch, output_size, length)``,
            upsampled and center-cropped to the input length.
        """
        output: torch.Tensor = input

        # Encoder pass; keep the pool outputs that will be fused back in.
        pools = []
        for idx, layer in enumerate(self.layers):
            output = layer(output)
            if self.combine_until <= idx < (self.num_layers - 1):
                pools.append(output)
        pools.reverse()  # fuse deep-to-shallow

        # Deepest score map first, then successively fuse shallower skips.
        output = self.score_convs[0](output)
        if pools:
            output = self.up_convs[0](output)
            for i, pool in enumerate(pools):
                score_pool = self.score_convs[i + 1](pool)
                # Skip maps are wider than the upsampled map: center-crop
                # before summing.
                output = output + self._center_crop(score_pool, output.shape[2])
                if i < len(pools) - 1:  # the final upsample is last_up_convs
                    output = self.up_convs[i + 1](output)
        output = self.last_up_convs(output)

        # Undo the oversized first padding: crop back to the input length.
        return self._center_crop(output, input.shape[2])