"""
paper: https://arxiv.org/abs/1802.02611
ref:
- https://github.com/tensorflow/models/tree/master/research/deeplab
- https://github.com/VainF/DeepLabV3Plus-Pytorch
- https://github.com/Hyunjulie/KR-Reading-Computer-Vision-Papers/blob/master/DeepLabv3%2B/deeplabv3p.py
"""
import math
import torch
from torch import nn
from torch.nn import functional as F
class AtrousSeparableConv1d(nn.Module):
def __init__(
self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False
):
super(AtrousSeparableConv1d, self).__init__()
        # depthwise: one filter per input channel (groups=inplanes); padding is
        # applied manually in forward() to emulate "same" padding for any rate.
        self.depthwise = nn.Conv1d(
            inplanes,
            inplanes,
            kernel_size,
            stride,
            0,
            dilation,
            groups=inplanes,
            bias=bias,
        )
        # pointwise: 1x1 convolution that mixes channels of the depthwise output
        self.pointwise = nn.Conv1d(
            inplanes, planes, kernel_size=1, stride=1, padding=0, bias=bias
        )
def forward(self, x):
x = self.apply_fixed_padding(
x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]
)
x = self.depthwise(x)
x = self.pointwise(x)
return x
def apply_fixed_padding(self, inputs, kernel_size, rate):
"""
ํ•ด๋‹น ํ•จ์ˆ˜๋Š” (dilation)rate ์™€ kernel_size ์— ๋”ฐ๋ผ output ์˜ ํฌ๊ธฐ๊ฐ€ input ์˜ ํฌ๊ธฐ์™€ ๋™์ผํ•ด์งˆ ์ˆ˜ ์žˆ๋„๋ก input ์— padding ์„ ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
๋‹ค๋งŒ, stride ๊ฐ€ 2 ์ด์ƒ์ธ ๊ฒฝ์šฐ์—๋Š” ํ•ด๋‹น ํ•จ์ˆ˜๋ฅผ ๊ฑฐ์น˜๋”๋ผ๋„ input ๊ณผ output ํฌ๊ธฐ๊ฐ€ ๋™์ผํ•ด์ง€์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ด ๊ฒฝ์šฐ๋Š” ์ตœ๋Œ€ํ•œ input ๊ณผ output ํฌ๊ธฐ๋ฅผ ๋งž์ถฐ์ฃผ๋Š” ๊ฒƒ์— ์˜๋ฏธ๊ฐ€ ์žˆ๊ณ , ์ „์ฒด ๋„คํŠธ์›Œํฌ์˜ ๋งˆ์ง€๋ง‰ upsample ๋‹จ๊ณ„์—์„œ ์ตœ์ข…์ ์œผ๋กœ ํฌ๊ธฐ๋ฅผ ๋งž์ถฐ์ค๋‹ˆ๋‹ค.
"""
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
padded_inputs = F.pad(inputs, (pad_beg, pad_end))
return padded_inputs
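
# A quick worked example of the fixed padding above (illustrative sketch, not
# part of the original file): with kernel_size=3 and rate=2 the effective
# kernel spans 3 + (3 - 1) * (2 - 1) = 5 samples, so pad_total = 4 and the
# input is padded by (2, 2). With stride=1 the output then keeps its length:
#
#     conv = AtrousSeparableConv1d(8, 16, kernel_size=3, dilation=2)
#     out = conv(torch.randn(1, 8, 100))  # -> shape (1, 16, 100)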
class Block(nn.Module):
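    """Xception-style residual block built from AtrousSeparableConv1d layers.

    Stacks `reps` separable convolutions and adds a skip connection; a 1x1
    strided projection is used on the skip path whenever the channel count
    or the temporal resolution changes (inplanes != planes or stride != 1).
    """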
def __init__(
self,
inplanes,
planes,
reps,
kernel_size=3,
stride=1,
dilation=1,
start_with_relu=True,
grow_first=True,
is_last=False,
):
super(Block, self).__init__()
if planes != inplanes or stride != 1:
self.skip = nn.Conv1d(inplanes, planes, 1, stride=stride, bias=False)
self.skipbn = nn.BatchNorm1d(planes)
else:
self.skip = None
self.relu = nn.ReLU(inplace=True)
rep = []
filters = inplanes
if grow_first:
rep.append(self.relu)
rep.append(
AtrousSeparableConv1d(
inplanes, planes, kernel_size, stride=1, dilation=dilation
)
)
rep.append(nn.BatchNorm1d(planes))
filters = planes
for _ in range(reps - 1):
rep.append(self.relu)
rep.append(
AtrousSeparableConv1d(
filters, filters, kernel_size, stride=1, dilation=dilation
)
)
rep.append(nn.BatchNorm1d(filters))
if not grow_first:
rep.append(self.relu)
rep.append(
AtrousSeparableConv1d(
inplanes, planes, kernel_size, stride=1, dilation=dilation
)
)
rep.append(nn.BatchNorm1d(planes))
        if not start_with_relu:
            rep = rep[1:]  # drop the leading ReLU
        # The modified aligned Xception replaces max pooling with a stride-2
        # separable conv at the end of a strided block; an is_last block at
        # stride 1 instead gets one extra stride-1 separable conv.
        if stride == 2:
            rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=2))
        elif stride == 1:
            if is_last:
                rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=1))
        else:
            raise NotImplementedError("stride must be 1 or 2 in Block.")
self.rep = nn.Sequential(*rep)
def forward(self, inp):
x = self.rep(inp)
if self.skip is not None:
skip = self.skip(inp)
skip = self.skipbn(skip)
else:
skip = inp
x += skip
return x
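
# Usage sketch (hypothetical values): an entry-flow block that doubles the
# channel count and halves the temporal resolution:
#
#     block = Block(64, 128, reps=2, stride=2, start_with_relu=False)
#     out = block(torch.randn(1, 64, 256))  # -> shape (1, 128, 128)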
class Xception(nn.Module):
"""Modified Aligned Xception"""
def __init__(
self,
inplanes=1,
output_stride=16,
kernel_size=3,
middle_repeat=16,
middle_block_rate=1,
exit_block_rates=(1, 2),
):
super(Xception, self).__init__()
if output_stride == 16:
entry3_stride = 2
elif output_stride == 8:
entry3_stride = 1
else:
raise NotImplementedError
self.conv1 = nn.Conv1d(
inplanes,
32,
kernel_size,
stride=2,
padding=(kernel_size - 1) // 2,
bias=False,
)
self.bn1 = nn.BatchNorm1d(32)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv1d(
32, 64, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=False
)
self.bn2 = nn.BatchNorm1d(64)
self.entry1 = Block(
64, 128, reps=2, kernel_size=kernel_size, stride=2, start_with_relu=False
)
self.entry2 = Block(
128,
256,
reps=2,
kernel_size=kernel_size,
stride=2,
start_with_relu=True,
grow_first=True,
)
self.entry3 = Block(
256,
728,
reps=2,
kernel_size=kernel_size,
stride=entry3_stride,
start_with_relu=True,
grow_first=True,
is_last=True,
)
self.middle = nn.Sequential(
*[
Block(
728,
728,
reps=3,
kernel_size=kernel_size,
stride=1,
dilation=middle_block_rate,
start_with_relu=True,
grow_first=True,
)
for _ in range(middle_repeat)
]
)
self.exit = Block(
728,
1024,
reps=2,
kernel_size=kernel_size,
stride=1,
dilation=exit_block_rates[0],
start_with_relu=True,
grow_first=False,
is_last=True,
)
self.conv3 = AtrousSeparableConv1d(
1024, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
)
self.bn3 = nn.BatchNorm1d(1536)
self.conv4 = AtrousSeparableConv1d(
1536, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
)
self.bn4 = nn.BatchNorm1d(1536)
self.conv5 = AtrousSeparableConv1d(
1536, 2048, kernel_size, stride=1, dilation=exit_block_rates[1]
)
self.bn5 = nn.BatchNorm1d(2048)
def forward(self, x: torch.Tensor):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
low_level = x = self.entry1(x)
x = self.entry2(x)
x = self.entry3(x)
x = self.middle(x)
x = self.exit(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu(x)
return x, low_level
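
# Shape trace for output_stride=16, inferred from the strides above: an input
# (N, 1, L) leaves the exit flow as (N, 2048, ~L/16), while the low-level
# features returned alongside it are (N, 128, ~L/4) taken after entry1.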
class ASPP(nn.Module):
"""Atrous Spatial Pyramid Pooling"""
def __init__(self, inplanes, planes, rate, kernel_size=3):
super(ASPP, self).__init__()
if rate == 1:
kernel_size = 1
padding = 0
else:
padding = rate * (kernel_size - 1) // 2
self.atrous_convolution = nn.Conv1d(
inplanes,
planes,
kernel_size=kernel_size,
stride=1,
padding=padding,
dilation=rate,
bias=False,
)
self.bn = nn.BatchNorm1d(planes)
self.relu = nn.ReLU()
def forward(self, x):
x = self.atrous_convolution(x)
x = self.bn(x)
return self.relu(x)
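
# In the DeepLabV3+ paper the ASPP branches use rates (1, 6, 12, 18) at
# output_stride 16; here the rates come from the config (see DeepLabV3Plus
# below), and the rate == 1 branch degenerates to a plain 1x1 convolution.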
class DeepLabV3Plus(nn.Module):
def __init__(self, config):
super(DeepLabV3Plus, self).__init__()
self.config = config
        # output_stride = input temporal resolution / encoder output resolution
output_stride = int(config.output_stride)
kernel_size = int(config.kernel_size)
middle_block_rate = int(config.middle_block_rate)
exit_block_rates: list = config.exit_block_rates
middle_repeat = int(config.middle_repeat)
self.interpolate_mode = str(config.interpolate_mode)
aspp_channel = int(config.aspp_channel)
aspp_rate: list = config.aspp_rate
        output_size = config.output_size  # number of classes, e.g. 3 (P, QRS, T)
self.xception_features = Xception(
output_stride=output_stride,
kernel_size=kernel_size,
middle_repeat=middle_repeat,
middle_block_rate=middle_block_rate,
exit_block_rates=exit_block_rates,
)
# ASPP
self.aspp1 = ASPP(
2048, aspp_channel, rate=aspp_rate[0], kernel_size=kernel_size
)
self.aspp2 = ASPP(
2048, aspp_channel, rate=aspp_rate[1], kernel_size=kernel_size
)
self.aspp3 = ASPP(
2048, aspp_channel, rate=aspp_rate[2], kernel_size=kernel_size
)
self.aspp4 = ASPP(
2048, aspp_channel, rate=aspp_rate[3], kernel_size=kernel_size
)
self.relu = nn.ReLU()
        # signal-level pooling branch of ASPP ("image pooling" in the paper)
        self.global_avg_pool = nn.Sequential(
nn.AdaptiveAvgPool1d(1),
nn.Conv1d(2048, aspp_channel, 1, stride=1, bias=False),
nn.BatchNorm1d(aspp_channel),
nn.ReLU(),
)
        # fuse the five ASPP branches back down to aspp_channel channels
        self.conv1 = nn.Conv1d(aspp_channel * 5, aspp_channel, 1, bias=False)
self.bn1 = nn.BatchNorm1d(aspp_channel)
# adopt [1x1, 48] for channel reduction.
self.conv2 = nn.Conv1d(128, 48, 1, bias=False)
self.bn2 = nn.BatchNorm1d(48)
self.last_conv = nn.Sequential(
nn.Conv1d(
aspp_channel + 48,
256,
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
bias=False,
),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Conv1d(
256,
256,
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
bias=False,
),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Conv1d(256, output_size, kernel_size=1, stride=1),
)
    def forward(self, inputs):
        x, low_level_features = self.xception_features(inputs)
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x5 = self.global_avg_pool(x)
x5 = F.interpolate(x5, size=x4.shape[2:], mode=self.interpolate_mode)
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
        # upsample the ASPP output to the low-level feature resolution (~L/4)
        x = F.interpolate(
            x, size=int(math.ceil(inputs.shape[-1] / 4)), mode=self.interpolate_mode
        )
low_level_features = self.conv2(low_level_features)
low_level_features = self.bn2(low_level_features)
low_level_features = self.relu(low_level_features)
x = torch.cat((x, low_level_features), dim=1)
x = self.last_conv(x)
        x = F.interpolate(x, size=inputs.shape[2:], mode=self.interpolate_mode)
return x
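
# A minimal smoke test (a sketch, not part of the original file): the config
# fields below are inferred from DeepLabV3Plus.__init__, and the values are
# illustrative assumptions, not the repository's actual configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        output_stride=16,
        kernel_size=3,
        middle_block_rate=1,
        exit_block_rates=(1, 2),
        middle_repeat=16,
        interpolate_mode="linear",
        aspp_channel=256,
        aspp_rate=(1, 6, 12, 18),
        output_size=3,
    )
    model = DeepLabV3Plus(config).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 1, 2048))  # (batch, leads, samples)
    print(logits.shape)  # expected: torch.Size([1, 3, 2048])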