"""
paper: https://arxiv.org/abs/1605.06211
ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
"""
import torch
import torch.nn as nn
class FCN(nn.Module):
    """1-D Fully Convolutional Network for dense per-timestep segmentation.

    Follows the FCN-8s skip-connection design (see module docstring for the
    paper / reference implementation), adapted from 2-D images to 1-D signals:
    a stack of conv+pool encoder layers, a 4096-channel "fc6/fc7" conv block,
    1x1 score convolutions on selected pool outputs, and transposed
    convolutions that upsample and fuse the scores back to input resolution.

    Expected ``config`` attributes (read in ``__init__``):
        kernel_size, last_layer_kernel_size, inplanes, num_convs, dilation,
        dropout, output_size, and ``combine_conf`` — a dict with
        ``num_layers`` (must be 4, 5 or 6) and ``combine_until`` (index of the
        first pool output fused via a skip connection; must be < num_layers).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        last_layer_kernel_size = int(config.last_layer_kernel_size)
        inplanes = int(config.inplanes)
        combine_conf: dict = config.combine_conf
        self.num_layers = int(combine_conf["num_layers"])
        # Oversized padding on the very first conv (analogous to the "pad 100"
        # trick in the reference Caffe FCN) so the repeated 2x pooling never
        # shrinks the signal to zero length. Raises KeyError for an
        # unsupported num_layers, which is the intended fail-fast behavior.
        self.first_padding = {6: 240, 5: 130, 4: 80}[self.num_layers]
        self.num_convs = int(config.num_convs)
        self.dilation = int(config.dilation)
        self.combine_until = int(combine_conf["combine_until"])
        # Validate eagerly with an exception instead of `assert`, which is
        # silently stripped when Python runs with the -O flag.
        if self.combine_until >= self.num_layers:
            raise ValueError(
                f"combine_until ({self.combine_until}) must be smaller than "
                f"num_layers ({self.num_layers})"
            )
        dropout = float(config.dropout)
        output_size = config.output_size  # e.g. 3 classes: (p, qrs, t)

        # Encoder: num_layers conv blocks, each ending in a 2x max-pool.
        # Channel width doubles per layer: inplanes, 2*inplanes, 4*inplanes, ...
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                self._make_layer(
                    1 if i == 0 else inplanes * 2 ** (i - 1),
                    inplanes * 2 ** i,
                    is_first=(i == 0),
                )
            )
        # Final conv block ("fc6/fc7" in FCN terms): unlike the encoder layers
        # it has no pooling, a fixed conv count (2) and channel width (4096),
        # and applies dropout. Compute its input width explicitly instead of
        # relying on the loop variable leaking out of the for-loop above.
        last_in = inplanes * 2 ** (self.num_layers - 1)
        self.layers.append(
            nn.Sequential(
                nn.Conv1d(last_in, 4096, last_layer_kernel_size),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Conv1d(4096, 4096, 1),
                nn.BatchNorm1d(4096),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        )

        # One score conv (1x1, to class logits) and one 2x up-conv per pool
        # output that participates in a skip connection.
        self.score_convs = []
        self.up_convs = []
        for i in range(self.combine_until, self.num_layers - 1):
            self.score_convs.append(
                nn.Conv1d(inplanes * 2 ** i, output_size, kernel_size=1, bias=False)
            )
            self.up_convs.append(
                nn.ConvTranspose1d(output_size, output_size, kernel_size=4, stride=2)
            )
        # Score conv applied to the pool-free final block's output; because of
        # it, score_convs always holds exactly one more module than up_convs.
        self.score_convs.append(nn.Conv1d(4096, output_size, kernel_size=1, bias=False))
        # Reverse so index 0 corresponds to the deepest (coarsest) features,
        # matching the deep-to-shallow fusion order used in forward().
        self.score_convs.reverse()
        self.score_convs = nn.ModuleList(self.score_convs)
        self.up_convs = nn.ModuleList(self.up_convs)
        # Final upsampling back toward input resolution: stride covers the
        # remaining 2x pooling factors, kernel_size = stride * 2.
        self.last_up_convs = nn.ConvTranspose1d(
            output_size,
            output_size,
            kernel_size=2 ** (self.combine_until + 1) * 2,
            stride=2 ** (self.combine_until + 1),
        )

    def _make_layer(
        self,
        in_channel: int,
        out_channel: int,
        is_first: bool = False,
    ) -> nn.Sequential:
        """Build one encoder block: num_convs x (Conv1d-BN-ReLU) + 2x max-pool.

        The very first conv of the whole network (``is_first`` and idx 0) gets
        the oversized ``self.first_padding``; every other conv uses "same"
        padding for its dilation/kernel combination.
        """
        modules = []
        plane = in_channel
        same_padding = (self.dilation * (self.kernel_size - 1)) // 2
        for idx in range(self.num_convs):
            padding = self.first_padding if (idx == 0 and is_first) else same_padding
            modules.append(
                nn.Conv1d(
                    plane,
                    out_channel,
                    kernel_size=self.kernel_size,
                    padding=padding,
                    dilation=self.dilation,
                    bias=False,
                )
            )
            modules.append(nn.BatchNorm1d(out_channel))
            modules.append(nn.ReLU())
            plane = out_channel
        # ceil_mode so odd-length signals keep their trailing sample.
        modules.append(nn.MaxPool1d(2, 2, ceil_mode=True))
        return nn.Sequential(*modules)

    def forward(self, input: torch.Tensor, y=None) -> torch.Tensor:
        """Return per-timestep class logits of shape (B, output_size, L).

        Args:
            input: (B, 1, L) signal tensor.
            y: unused; kept for interface compatibility with callers that pass
               targets alongside the input.
        """
        output: torch.Tensor = input
        # Run the encoder, stashing the pool outputs selected for skip fusion.
        pools = []
        for idx, layer in enumerate(self.layers):
            output = layer(output)
            if self.combine_until <= idx < (self.num_layers - 1):
                pools.append(output)
        # Deepest features first, matching the reversed score_convs.
        pools.reverse()

        output = self.score_convs[0](output)
        if pools:
            output = self.up_convs[0](output)
        for i, pool in enumerate(pools):
            score_pool = self.score_convs[i + 1](pool)
            # Center-crop the (longer) skip scores to the upsampled length,
            # then fuse by element-wise addition.
            offset = (score_pool.shape[2] - output.shape[2]) // 2
            cropped_score_pool = torch.tensor_split(
                score_pool, (offset, offset + output.shape[2]), dim=2
            )[1]
            output = torch.add(cropped_score_pool, output)
            # The last fusion step is followed by last_up_convs instead.
            if i < len(pools) - 1:
                output = self.up_convs[i + 1](output)
        output = self.last_up_convs(output)
        # Center-crop back to the original input length.
        offset = (output.shape[2] - input.shape[2]) // 2
        cropped_output = torch.tensor_split(
            output, (offset, offset + input.shape[2]), dim=2
        )[1]
        return cropped_output