""" | |
paper: https://arxiv.org/abs/1802.02611 | |
ref: | |
- https://github.com/tensorflow/models/tree/master/research/deeplab | |
- https://github.com/VainF/DeepLabV3Plus-Pytorch | |
- https://github.com/Hyunjulie/KR-Reading-Computer-Vision-Papers/blob/master/DeepLabv3%2B/deeplabv3p.py | |
""" | |
import math

import torch
from torch import nn
from torch.nn import functional as F


class AtrousSeparableConv1d(nn.Module):
    def __init__(
        self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False
    ):
        super(AtrousSeparableConv1d, self).__init__()
        # Depthwise conv: one filter per input channel (groups=inplanes).
        # Padding is applied manually in forward() so that the dilation rate
        # is taken into account ("same"-style padding).
        self.depthwise = nn.Conv1d(
            inplanes,
            inplanes,
            kernel_size,
            stride,
            padding=0,
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        # Pointwise (kernel-size-1) conv mixes channels after the depthwise step.
        self.pointwise = nn.Conv1d(inplanes, planes, kernel_size=1, bias=bias)

    def forward(self, x):
        # Pad so the (dilated) depthwise conv behaves like "same" padding.
        x = self.apply_fixed_padding(
            x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]
        )
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

    def apply_fixed_padding(self, inputs, kernel_size, rate):
        """
        Pads the input so that, given the (dilation) rate and kernel_size,
        the output keeps the same size as the input.

        Note that when stride is 2 or more, the input and output sizes can
        still differ even after this padding. In that case there is little
        point in forcing them to match; the final upsample stage of the
        network restores the target size at the end.
        """
        kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
        pad_total = kernel_size_effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        padded_inputs = F.pad(inputs, (pad_beg, pad_end))
        return padded_inputs


class Block(nn.Module):
    def __init__(
        self,
        inplanes,
        planes,
        reps,
        kernel_size=3,
        stride=1,
        dilation=1,
        start_with_relu=True,
        grow_first=True,
        is_last=False,
    ):
        super(Block, self).__init__()
        # 1x1 projection on the residual path whenever the shape changes.
        if planes != inplanes or stride != 1:
            self.skip = nn.Conv1d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm1d(planes)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)

        rep = []
        filters = inplanes
        # Optionally grow the channel count in the first separable conv.
        if grow_first:
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    inplanes, planes, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(planes))
            filters = planes

        for _ in range(reps - 1):
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    filters, filters, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    inplanes, planes, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(planes))

        if not start_with_relu:
            rep = rep[1:]

        # Downsample with stride 2, or add an extra stride-1 conv for the
        # last block of a flow.
        if stride == 2:
            rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=2))
        elif stride == 1:
            if is_last:
                rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=1))
        else:
            raise NotImplementedError("stride must be 1 or 2 in Block.")

        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp
        x += skip
        return x


class Xception(nn.Module):
    """Modified Aligned Xception"""

    def __init__(
        self,
        inplanes=1,
        output_stride=16,
        kernel_size=3,
        middle_repeat=16,
        middle_block_rate=1,
        exit_block_rates=(1, 2),
    ):
        super(Xception, self).__init__()
        # With output_stride=16 the last entry block downsamples; with
        # output_stride=8 it keeps stride 1 and the dilated convs take over.
        if output_stride == 16:
            entry3_stride = 2
        elif output_stride == 8:
            entry3_stride = 1
        else:
            raise NotImplementedError("output_stride must be 8 or 16.")
        # Entry flow: a two-conv stem followed by three residual blocks.
        self.conv1 = nn.Conv1d(
            inplanes,
            32,
            kernel_size,
            stride=2,
            padding=(kernel_size - 1) // 2,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            32, 64, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=False
        )
        self.bn2 = nn.BatchNorm1d(64)

        self.entry1 = Block(
            64, 128, reps=2, kernel_size=kernel_size, stride=2, start_with_relu=False
        )
        self.entry2 = Block(
            128,
            256,
            reps=2,
            kernel_size=kernel_size,
            stride=2,
            start_with_relu=True,
            grow_first=True,
        )
        self.entry3 = Block(
            256,
            728,
            reps=2,
            kernel_size=kernel_size,
            stride=entry3_stride,
            start_with_relu=True,
            grow_first=True,
            is_last=True,
        )
        # Middle flow: repeated residual blocks at constant width.
        self.middle = nn.Sequential(
            *[
                Block(
                    728,
                    728,
                    reps=3,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=middle_block_rate,
                    start_with_relu=True,
                    grow_first=True,
                )
                for _ in range(middle_repeat)
            ]
        )

        # Exit flow: one block plus three dilated separable convs.
        self.exit = Block(
            728,
            1024,
            reps=2,
            kernel_size=kernel_size,
            stride=1,
            dilation=exit_block_rates[0],
            start_with_relu=True,
            grow_first=False,
            is_last=True,
        )
        self.conv3 = AtrousSeparableConv1d(
            1024, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn3 = nn.BatchNorm1d(1536)
        self.conv4 = AtrousSeparableConv1d(
            1536, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn4 = nn.BatchNorm1d(1536)
        self.conv5 = AtrousSeparableConv1d(
            1536, 2048, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn5 = nn.BatchNorm1d(2048)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        # The entry1 output (1/4 resolution) also feeds the decoder as
        # low-level features.
        low_level = x = self.entry1(x)
        x = self.entry2(x)
        x = self.entry3(x)
        x = self.middle(x)
        x = self.exit(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)
        return x, low_level


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling"""

    def __init__(self, inplanes, planes, rate, kernel_size=3):
        super(ASPP, self).__init__()
        # rate=1 degenerates to a plain kernel-size-1 conv (the 1x1 branch
        # of ASPP).
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            padding = rate * (kernel_size - 1) // 2
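            # e.g. rate=6 with kernel_size=3 gives padding=6; a stride-1
            # dilated conv then preserves the temporal length exactly.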
        self.atrous_convolution = nn.Conv1d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=rate,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)
        return self.relu(x)


class DeepLabV3Plus(nn.Module):
    def __init__(self, config):
        super(DeepLabV3Plus, self).__init__()
        self.config = config
        # output_stride: (input's spatial resolution / output's resolution)
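        # e.g. with output_stride=16, a length-2048 input yields encoder
        # features of length 2048 / 16 = 128 before the decoder upsamples.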
        output_stride = int(config.output_stride)
        kernel_size = int(config.kernel_size)
        middle_block_rate = int(config.middle_block_rate)
        exit_block_rates: list = config.exit_block_rates
        middle_repeat = int(config.middle_repeat)
        self.interpolate_mode = str(config.interpolate_mode)
        aspp_channel = int(config.aspp_channel)
        aspp_rate: list = config.aspp_rate
        output_size = config.output_size  # number of classes, e.g. 3: (p, qrs, t)
        # Backbone encoder.
        self.xception_features = Xception(
            output_stride=output_stride,
            kernel_size=kernel_size,
            middle_repeat=middle_repeat,
            middle_block_rate=middle_block_rate,
            exit_block_rates=exit_block_rates,
        )

        # ASPP: four parallel atrous branches plus image-level pooling.
        self.aspp1 = ASPP(
            2048, aspp_channel, rate=aspp_rate[0], kernel_size=kernel_size
        )
        self.aspp2 = ASPP(
            2048, aspp_channel, rate=aspp_rate[1], kernel_size=kernel_size
        )
        self.aspp3 = ASPP(
            2048, aspp_channel, rate=aspp_rate[2], kernel_size=kernel_size
        )
        self.aspp4 = ASPP(
            2048, aspp_channel, rate=aspp_rate[3], kernel_size=kernel_size
        )
        self.relu = nn.ReLU()

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(2048, aspp_channel, 1, stride=1, bias=False),
            nn.BatchNorm1d(aspp_channel),
            nn.ReLU(),
        )
        # Project the concatenated ASPP branches (5 * aspp_channel) back down.
        self.conv1 = nn.Conv1d(aspp_channel * 5, aspp_channel, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(aspp_channel)

        # Adopt [1x1, 48] for low-level feature channel reduction.
        self.conv2 = nn.Conv1d(128, 48, 1, bias=False)
        self.bn2 = nn.BatchNorm1d(48)

        # Decoder head: two kernel_size convs, then a 1x1 classifier.
        self.last_conv = nn.Sequential(
            nn.Conv1d(
                aspp_channel + 48,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(
                256,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, output_size, kernel_size=1, stride=1),
        )

    def forward(self, input):
        x, low_level_features = self.xception_features(input)

        # ASPP branches plus image-level pooling, upsampled and concatenated.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.shape[2:], mode=self.interpolate_mode)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # Upsample to 1/4 resolution and fuse with the low-level features.
        x = F.interpolate(
            x, size=int(math.ceil(input.shape[-1] / 4)), mode=self.interpolate_mode
        )
        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)

        # Final upsample back to the input resolution.
        x = F.interpolate(x, size=input.shape[2:], mode=self.interpolate_mode)
        return x
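

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the real project presumably supplies its
    # own config object). The field names mirror the attributes read in
    # DeepLabV3Plus.__init__ above; the values are illustrative assumptions.
    from types import SimpleNamespace

    config = SimpleNamespace(
        output_stride=16,
        kernel_size=3,
        middle_block_rate=1,
        exit_block_rates=(1, 2),
        middle_repeat=16,
        interpolate_mode="linear",
        aspp_channel=256,
        aspp_rate=(1, 6, 12, 18),
        output_size=3,  # p, qrs, t
    )

    model = DeepLabV3Plus(config).eval()
    with torch.no_grad():
        out = model(torch.randn(2, 1, 2048))  # (batch, channels, length)
    print(out.shape)  # expected: torch.Size([2, 3, 2048])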