# refactor: add implementations (commit aefacda)
"""
paper: https://arxiv.org/abs/1605.06211
ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
"""
import torch
import torch.nn as nn
class FCN(nn.Module):
    """1-D fully convolutional network (FCN-8s style) for sequence segmentation.

    Builds ``num_layers`` conv+pool stages, a pool-free 4096-channel head
    (the fc6/fc7 analogue of the reference FCN), and a skip-connection
    upsampling path that fuses the pooled feature maps from stage
    ``combine_until`` onward back up to the input resolution.

    paper: https://arxiv.org/abs/1605.06211
    ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        last_layer_kernel_size = int(config.last_layer_kernel_size)
        inplanes = int(config.inplanes)
        combine_conf: dict = config.combine_conf
        self.num_layers = int(combine_conf["num_layers"])
        # Large first-conv padding (the analogue of FCN's pad=100) so that
        # repeated ceil-mode poolings leave enough context to crop from later.
        # Only 4, 5 or 6 stages are supported.
        self.first_padding = {6: 240, 5: 130, 4: 80}[self.num_layers]
        self.num_convs = int(config.num_convs)
        self.dilation = int(config.dilation)
        self.combine_until = int(combine_conf["combine_until"])
        if self.combine_until >= self.num_layers:
            raise ValueError(
                f"combine_until ({self.combine_until}) must be smaller than "
                f"num_layers ({self.num_layers})"
            )
        dropout = float(config.dropout)
        output_size = config.output_size  # number of classes, e.g. 3 (p, qrs, t)

        # Stage i maps inplanes * 2**(i-1) -> inplanes * 2**i channels
        # (the very first stage takes the single raw input channel).
        self.layers = nn.ModuleList(
            self._make_layer(
                1 if i == 0 else inplanes * 2 ** (i - 1),
                inplanes * 2 ** i,
                is_first=(i == 0),
            )
            for i in range(self.num_layers)
        )
        # Pool-free final stage: unlike the other stages it has a fixed
        # number of convs (2), a fixed channel width (4096), and dropout.
        top_channels = inplanes * 2 ** (self.num_layers - 1)
        self.layers.append(
            nn.Sequential(
                nn.Conv1d(top_channels, 4096, last_layer_kernel_size),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Conv1d(4096, 4096, 1),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        )

        # One (score_conv, up_conv) pair per pooled feature map that gets
        # fused into the upsampling path.
        score_convs = []
        up_convs = []
        for i in range(self.combine_until, self.num_layers - 1):
            score_convs.append(
                nn.Conv1d(inplanes * 2 ** i, output_size, kernel_size=1, bias=False)
            )
            up_convs.append(
                nn.ConvTranspose1d(output_size, output_size, kernel_size=4, stride=2)
            )
        # Score conv applied to the pool-free final stage's output; after the
        # reverse, score_convs[0] is this one, so score_convs always holds
        # exactly one more module than up_convs.
        score_convs.append(nn.Conv1d(4096, output_size, kernel_size=1, bias=False))
        score_convs.reverse()
        self.score_convs = nn.ModuleList(score_convs)
        self.up_convs = nn.ModuleList(up_convs)

        # Final upsampling back toward the input resolution; kernel size is
        # twice the stride.
        stride = 2 ** (self.combine_until + 1)
        self.last_up_convs = nn.ConvTranspose1d(
            output_size,
            output_size,
            kernel_size=stride * 2,
            stride=stride,
        )

    def _make_layer(
        self,
        in_channel: int,
        out_channel: int,
        is_first: bool = False,
    ) -> nn.Sequential:
        """Build one stage: num_convs x (Conv1d -> BN -> ReLU) + ceil-mode MaxPool."""
        modules = []
        plane = in_channel
        # "same"-length padding for a dilated conv with the configured kernel.
        same_padding = (self.dilation * (self.kernel_size - 1)) // 2
        for idx in range(self.num_convs):
            modules.append(
                nn.Conv1d(
                    plane,
                    out_channel,
                    kernel_size=self.kernel_size,
                    # The network's very first conv gets the large FCN-style
                    # padding; every other conv preserves the length.
                    padding=self.first_padding if (idx == 0 and is_first) else same_padding,
                    dilation=self.dilation,
                    bias=False,
                )
            )
            modules.append(nn.BatchNorm1d(out_channel))
            modules.append(nn.ReLU())
            plane = out_channel
        modules.append(nn.MaxPool1d(2, 2, ceil_mode=True))
        return nn.Sequential(*modules)

    @staticmethod
    def _center_crop(x: torch.Tensor, length: int) -> torch.Tensor:
        """Center-crop ``x`` along the time axis (dim 2) to ``length``."""
        offset = (x.shape[2] - length) // 2
        return torch.tensor_split(x, (offset, offset + length), dim=2)[1]

    def forward(self, input: torch.Tensor, y=None):
        # NOTE: `y` is unused here; kept for caller-interface compatibility.
        output: torch.Tensor = input
        pools = []
        for idx, layer in enumerate(self.layers):
            output = layer(output)
            # Keep the pooled feature maps that will be fused while upsampling
            # (the final, pool-free stage at idx == num_layers is excluded).
            if self.combine_until <= idx < self.num_layers - 1:
                pools.append(output)
        pools.reverse()  # fuse from coarsest (deepest) to finest

        output = self.score_convs[0](output)
        if pools:
            output = self.up_convs[0](output)
            for i, pool in enumerate(pools):
                score_pool = self.score_convs[i + 1](pool)
                # Crop the skip branch to the upsampled length before fusing.
                output = output + self._center_crop(score_pool, output.shape[2])
                if i < len(pools) - 1:  # the last upsampling uses last_up_convs
                    output = self.up_convs[i + 1](output)
        output = self.last_up_convs(output)
        # Crop back to the original input length.
        return self._center_crop(output, input.shape[2])