Spaces:
Sleeping
Sleeping
"""
paper: https://arxiv.org/abs/1605.06211
ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
"""
import torch | |
import torch.nn as nn | |
class FCN(nn.Module):
    """1-D fully convolutional network with skip connections (FCN-8s style).

    The network consumes a single-channel signal of shape
    ``(batch, 1, length)`` and returns a score map of shape
    ``(batch, output_size, length)``: the upsampled scores are centre-cropped
    back to the length of the input.

    ``config`` must expose the following attributes:

    * ``kernel_size``: kernel size of the convolutions in the pooled stages.
    * ``last_layer_kernel_size``: kernel size of the first convolution of the
      final (non-pooled) 4096-channel block.
    * ``inplanes``: output channels of the first stage; doubled every stage.
    * ``combine_conf``: mapping with ``"num_layers"`` (number of pooled
      stages, one of 4, 5 or 6) and ``"combine_until"`` (index of the
      shallowest stage whose pooled output is fused back in).
    * ``num_convs``: number of conv/BN/ReLU triples per pooled stage.
    * ``dilation``: dilation of the convolutions in the pooled stages.
    * ``dropout``: dropout probability used in the final block.
    * ``output_size``: number of output channels, e.g. 3 (p, qrs, t).

    Raises:
        KeyError: if ``num_layers`` has no entry in the first-padding table.
        ValueError: if ``combine_until`` is not in ``[0, num_layers)``.
    """

    # Padding applied to the very first convolution, keyed by ``num_layers``.
    # NOTE(review): presumably chosen (like pad=100 in the reference FCN) so
    # that every upsampled map is at least as long as the tensor it is
    # cropped against -- confirm before adding entries for other depths.
    _FIRST_PADDING = {6: 240, 5: 130, 4: 80}

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        last_layer_kernel_size = int(config.last_layer_kernel_size)
        inplanes = int(config.inplanes)
        combine_conf: dict = config.combine_conf
        self.num_layers = int(combine_conf["num_layers"])
        self.first_padding = self._FIRST_PADDING[self.num_layers]
        self.num_convs = int(config.num_convs)
        self.dilation = int(config.dilation)
        self.combine_until = int(combine_conf["combine_until"])
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        # A negative value would yield fractional channel counts and a
        # skip/score mismatch in ``forward``.
        if not 0 <= self.combine_until < self.num_layers:
            raise ValueError(
                f"combine_until ({self.combine_until}) must satisfy "
                f"0 <= combine_until < num_layers ({self.num_layers})"
            )
        dropout = float(config.dropout)
        output_size = config.output_size  # e.g. 3 (p, qrs, t)

        # Pooled stages: channels double at every stage, input has 1 channel.
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                self._make_layer(
                    1 if i == 0 else inplanes * (2 ** (i - 1)),
                    inplanes * (2 ** i),
                    is_first=(i == 0),
                )
            )

        # Final block without pooling. Unlike the other stages, the number of
        # convolutions (2) and the channel width (4096) are fixed, and
        # dropout is applied. Its input width is that of the last pooled stage.
        last_stage_channels = inplanes * (2 ** (self.num_layers - 1))
        self.layers.append(
            nn.Sequential(
                nn.Conv1d(last_stage_channels, 4096, last_layer_kernel_size),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Conv1d(4096, 4096, 1),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        )

        # One score conv and one 2x up-conv per pooled stage that is fused
        # back in (stages ``combine_until`` .. ``num_layers - 2``).
        score_convs = []
        up_convs = []
        for i in range(self.combine_until, self.num_layers - 1):
            score_convs.append(
                nn.Conv1d(inplanes * (2 ** i), output_size, kernel_size=1, bias=False)
            )
            up_convs.append(
                nn.ConvTranspose1d(output_size, output_size, kernel_size=4, stride=2)
            )
        # Score conv for the output of the final (non-pooled) block, so there
        # is always exactly one more score conv than up-convs.
        score_convs.append(nn.Conv1d(4096, output_size, kernel_size=1, bias=False))
        # Deepest first, matching the order in which ``forward`` consumes them.
        score_convs.reverse()
        self.score_convs = nn.ModuleList(score_convs)
        self.up_convs = nn.ModuleList(up_convs)

        # Final upsampling back to (at least) the input resolution.
        last_stride = 2 ** (self.combine_until + 1)
        self.last_up_convs = nn.ConvTranspose1d(
            output_size,
            output_size,
            kernel_size=last_stride * 2,
            stride=last_stride,
        )

    def _make_layer(
        self,
        in_channel: int,
        out_channel: int,
        is_first: bool = False,
    ):
        """Build one pooled stage: ``num_convs`` x (conv, BN, ReLU) + max-pool.

        Args:
            in_channel: input channels of the first convolution.
            out_channel: output channels of every convolution in the stage.
            is_first: when true, the first convolution uses
                ``self.first_padding`` instead of the regular padding.

        Returns:
            An ``nn.Sequential`` ending in a stride-2 ``MaxPool1d``.
        """
        regular_padding = (self.dilation * (self.kernel_size - 1)) // 2
        modules = []
        plane = in_channel
        for idx in range(self.num_convs):
            padding = self.first_padding if idx == 0 and is_first else regular_padding
            modules.append(
                nn.Conv1d(
                    plane,
                    out_channel,
                    kernel_size=self.kernel_size,
                    padding=padding,
                    dilation=self.dilation,
                    bias=False,
                )
            )
            modules.append(nn.BatchNorm1d(out_channel))
            modules.append(nn.ReLU())
            plane = out_channel
        modules.append(nn.MaxPool1d(2, 2, ceil_mode=True))
        return nn.Sequential(*modules)

    @staticmethod
    def _center_crop(tensor: torch.Tensor, length: int) -> torch.Tensor:
        """Return the centred window of ``length`` samples along dim 2."""
        offset = (tensor.shape[2] - length) // 2
        return tensor[:, :, offset:offset + length]

    def forward(self, input: torch.Tensor, y=None):
        """Compute the score map for ``input``.

        Args:
            input: tensor whose dim 1 has a single channel and whose dim 2 is
                the sequence axis.
            y: unused; kept so existing callers passing it keep working.

        Returns:
            Tensor with ``output_size`` channels, centre-cropped along dim 2
            to ``input.shape[2]``.
        """
        output: torch.Tensor = input
        pools = []
        for idx, layer in enumerate(self.layers):
            output = layer(output)
            # Keep the pooled outputs that are fused back in later.
            if self.combine_until <= idx < (self.num_layers - 1):
                pools.append(output)
        pools.reverse()  # deepest first, same order as ``self.score_convs[1:]``

        output = self.score_convs[0](output)
        if pools:
            output = self.up_convs[0](output)
        for i, pool in enumerate(pools):
            score_pool = self.score_convs[i + 1](pool)
            cropped_score_pool = self._center_crop(score_pool, output.shape[2])
            output = torch.add(cropped_score_pool, output)
            # The last upsampling step is done by ``last_up_convs`` below.
            if i < len(pools) - 1:
                output = self.up_convs[i + 1](output)

        output = self.last_up_convs(output)
        return self._center_crop(output, input.shape[2])