""" paper: https://arxiv.org/abs/1802.02611 ref: - https://github.com/tensorflow/models/tree/master/research/deeplab - https://github.com/VainF/DeepLabV3Plus-Pytorch - https://github.com/Hyunjulie/KR-Reading-Computer-Vision-Papers/blob/master/DeepLabv3%2B/deeplabv3p.py """ import math import torch from torch import nn from torch.functional import F class AtrousSeparableConv1d(nn.Module): def __init__( self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False ): super(AtrousSeparableConv1d, self).__init__() self.depthwise = nn.Conv1d( inplanes, inplanes, kernel_size, stride, 0, dilation, groups=inplanes, bias=bias, ) self.pointwise = nn.Conv1d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) def forward(self, x): x = self.apply_fixed_padding( x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0] ) x = self.depthwise(x) x = self.pointwise(x) return x def apply_fixed_padding(self, inputs, kernel_size, rate): """ 해당 함수는 (dilation)rate 와 kernel_size 에 따라 output 의 크기가 input 의 크기와 동일해질 수 있도록 input 에 padding 을 적용합니다. 다만, stride 가 2 이상인 경우에는 해당 함수를 거치더라도 input 과 output 크기가 동일해지지 않을 수 있습니다. 이 경우는 최대한 input 과 output 크기를 맞춰주는 것에 의미가 있고, 전체 네트워크의 마지막 upsample 단계에서 최종적으로 크기를 맞춰줍니다. """ kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) pad_total = kernel_size_effective - 1 pad_beg = pad_total // 2 pad_end = pad_total - pad_beg padded_inputs = F.pad(inputs, (pad_beg, pad_end)) return padded_inputs class Block(nn.Module): def __init__( self, inplanes, planes, reps, kernel_size=3, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False, ): super(Block, self).__init__() if planes != inplanes or stride != 1: self.skip = nn.Conv1d(inplanes, planes, 1, stride=stride, bias=False) self.skipbn = nn.BatchNorm1d(planes) else: self.skip = None self.relu = nn.ReLU(inplace=True) rep = [] filters = inplanes if grow_first: rep.append(self.relu) rep.append( AtrousSeparableConv1d( inplanes, planes, kernel_size, stride=1, dilation=dilation ) ) rep.append(nn.BatchNorm1d(planes)) filters = planes for _ in range(reps - 1): rep.append(self.relu) rep.append( AtrousSeparableConv1d( filters, filters, kernel_size, stride=1, dilation=dilation ) ) rep.append(nn.BatchNorm1d(filters)) if not grow_first: rep.append(self.relu) rep.append( AtrousSeparableConv1d( inplanes, planes, kernel_size, stride=1, dilation=dilation ) ) rep.append(nn.BatchNorm1d(planes)) if not start_with_relu: rep = rep[1:] if stride == 2: rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=2)) elif stride == 1: if is_last: rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=1)) else: raise NotImplementedError("stride must be 1 or 2 in Block.") self.rep = nn.Sequential(*rep) def forward(self, inp): x = self.rep(inp) if self.skip is not None: skip = self.skip(inp) skip = self.skipbn(skip) else: skip = inp x += skip return x class Xception(nn.Module): """Modified Aligned Xception""" def __init__( self, inplanes=1, output_stride=16, kernel_size=3, middle_repeat=16, middle_block_rate=1, exit_block_rates=(1, 2), ): super(Xception, self).__init__() if output_stride == 16: entry3_stride = 2 elif output_stride == 8: entry3_stride = 1 else: raise NotImplementedError self.conv1 = nn.Conv1d( inplanes, 32, kernel_size, stride=2, padding=(kernel_size - 1) // 2, bias=False, ) self.bn1 = nn.BatchNorm1d(32) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv1d( 32, 64, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=False ) self.bn2 = nn.BatchNorm1d(64) self.entry1 = Block( 
class Block(nn.Module):
    """Xception-style residual block built from atrous separable convs."""

    def __init__(
        self,
        inplanes,
        planes,
        reps,
        kernel_size=3,
        stride=1,
        dilation=1,
        start_with_relu=True,
        grow_first=True,
        is_last=False,
    ):
        super(Block, self).__init__()
        # Project the skip path with a 1x1 conv when the shape changes.
        if planes != inplanes or stride != 1:
            self.skip = nn.Conv1d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm1d(planes)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)

        rep = []
        filters = inplanes
        if grow_first:
            # Grow channels (inplanes -> planes) in the first conv.
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    inplanes, planes, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(planes))
            filters = planes

        for _ in range(reps - 1):
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    filters, filters, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(filters))

        if not grow_first:
            # Grow channels in the last conv instead.
            rep.append(self.relu)
            rep.append(
                AtrousSeparableConv1d(
                    inplanes, planes, kernel_size, stride=1, dilation=dilation
                )
            )
            rep.append(nn.BatchNorm1d(planes))

        if not start_with_relu:
            rep = rep[1:]

        # Tail conv: downsample when stride == 2, or add one more stride-1
        # conv for the last block of a stage.
        if stride == 2:
            rep.append(AtrousSeparableConv1d(planes, planes, kernel_size, stride=2))
        elif stride == 1:
            if is_last:
                rep.append(
                    AtrousSeparableConv1d(planes, planes, kernel_size, stride=1)
                )
        else:
            raise NotImplementedError("stride must be 1 or 2 in Block.")

        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp
        x += skip
        return x


class Xception(nn.Module):
    """Modified Aligned Xception backbone (entry / middle / exit flow)."""

    def __init__(
        self,
        inplanes=1,
        output_stride=16,
        kernel_size=3,
        middle_repeat=16,
        middle_block_rate=1,
        exit_block_rates=(1, 2),
    ):
        super(Xception, self).__init__()
        # output_stride 16 keeps entry3 at stride 2; output_stride 8 keeps
        # entry3 at stride 1 so the features stay twice as long.
        if output_stride == 16:
            entry3_stride = 2
        elif output_stride == 8:
            entry3_stride = 1
        else:
            raise NotImplementedError

        self.conv1 = nn.Conv1d(
            inplanes,
            32,
            kernel_size,
            stride=2,
            padding=(kernel_size - 1) // 2,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            32, 64, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=False
        )
        self.bn2 = nn.BatchNorm1d(64)

        # Entry flow.
        self.entry1 = Block(
            64, 128, reps=2, kernel_size=kernel_size, stride=2, start_with_relu=False
        )
        self.entry2 = Block(
            128,
            256,
            reps=2,
            kernel_size=kernel_size,
            stride=2,
            start_with_relu=True,
            grow_first=True,
        )
        self.entry3 = Block(
            256,
            728,
            reps=2,
            kernel_size=kernel_size,
            stride=entry3_stride,
            start_with_relu=True,
            grow_first=True,
            is_last=True,
        )

        # Middle flow: `middle_repeat` identical residual blocks.
        self.middle = nn.Sequential(
            *[
                Block(
                    728,
                    728,
                    reps=3,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=middle_block_rate,
                    start_with_relu=True,
                    grow_first=True,
                )
                for _ in range(middle_repeat)
            ]
        )

        # Exit flow.
        self.exit = Block(
            728,
            1024,
            reps=2,
            kernel_size=kernel_size,
            stride=1,
            dilation=exit_block_rates[0],
            start_with_relu=True,
            grow_first=False,
            is_last=True,
        )
        self.conv3 = AtrousSeparableConv1d(
            1024, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn3 = nn.BatchNorm1d(1536)
        self.conv4 = AtrousSeparableConv1d(
            1536, 1536, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn4 = nn.BatchNorm1d(1536)
        self.conv5 = AtrousSeparableConv1d(
            1536, 2048, kernel_size, stride=1, dilation=exit_block_rates[1]
        )
        self.bn5 = nn.BatchNorm1d(2048)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        # entry1's output (stride 4) doubles as the decoder's low-level input.
        low_level = x = self.entry1(x)
        x = self.entry2(x)
        x = self.entry3(x)
        x = self.middle(x)
        x = self.exit(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)
        return x, low_level


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling branch (single dilation rate)."""

    def __init__(self, inplanes, planes, rate, kernel_size=3):
        super(ASPP, self).__init__()
        if rate == 1:
            # The rate-1 branch degenerates to a 1x1 conv, as in the paper.
            kernel_size = 1
            padding = 0
        else:
            # Symmetric padding that keeps the temporal length unchanged.
            padding = rate * (kernel_size - 1) // 2
        self.atrous_convolution = nn.Conv1d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=rate,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)
        return self.relu(x)
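# Shape sketch for one ASPP branch (illustration only; the sizes below are
# assumptions): padding = rate * (kernel_size - 1) // 2 cancels the dilated
# receptive-field growth, so every branch preserves the temporal length, e.g.
#
#     aspp = ASPP(2048, 256, rate=6, kernel_size=3)
#     out = aspp(torch.randn(1, 2048, 128))  # -> torch.Size([1, 256, 128])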
class DeepLabV3Plus(nn.Module):
    def __init__(self, config):
        super(DeepLabV3Plus, self).__init__()
        self.config = config
        # output_stride: input's spatial resolution / output's resolution.
        output_stride = int(config.output_stride)
        kernel_size = int(config.kernel_size)
        middle_block_rate = int(config.middle_block_rate)
        exit_block_rates: list = config.exit_block_rates
        middle_repeat = int(config.middle_repeat)
        self.interpolate_mode = str(config.interpolate_mode)
        aspp_channel = int(config.aspp_channel)
        aspp_rate: list = config.aspp_rate
        output_size = config.output_size  # e.g. 3 (p, qrs, t)

        # Backbone.
        self.xception_features = Xception(
            output_stride=output_stride,
            kernel_size=kernel_size,
            middle_repeat=middle_repeat,
            middle_block_rate=middle_block_rate,
            exit_block_rates=exit_block_rates,
        )

        # ASPP: four dilated branches plus global average pooling.
        self.aspp1 = ASPP(
            2048, aspp_channel, rate=aspp_rate[0], kernel_size=kernel_size
        )
        self.aspp2 = ASPP(
            2048, aspp_channel, rate=aspp_rate[1], kernel_size=kernel_size
        )
        self.aspp3 = ASPP(
            2048, aspp_channel, rate=aspp_rate[2], kernel_size=kernel_size
        )
        self.aspp4 = ASPP(
            2048, aspp_channel, rate=aspp_rate[3], kernel_size=kernel_size
        )
        self.relu = nn.ReLU()
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(2048, aspp_channel, 1, stride=1, bias=False),
            nn.BatchNorm1d(aspp_channel),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv1d(aspp_channel * 5, aspp_channel, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(aspp_channel)

        # Decoder: adopt [1x1, 48] for low-level-feature channel reduction.
        self.conv2 = nn.Conv1d(128, 48, 1, bias=False)
        self.bn2 = nn.BatchNorm1d(48)
        self.last_conv = nn.Sequential(
            nn.Conv1d(
                aspp_channel + 48,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(
                256,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, output_size, kernel_size=1, stride=1),
        )

    def forward(self, input):
        x, low_level_features = self.xception_features(input)

        # ASPP branches run in parallel on the backbone output; the pooled
        # branch is upsampled back to the branch resolution before concat.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.shape[2:], mode=self.interpolate_mode)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # Upsample to the low-level feature resolution (stride 4).
        x = F.interpolate(
            x, size=int(math.ceil(input.shape[-1] / 4)), mode=self.interpolate_mode
        )

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        # Final upsample back to the input resolution.
        x = F.interpolate(x, size=input.shape[2:], mode=self.interpolate_mode)
        return x
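
# Minimal smoke test (a sketch, not part of the model). The config fields
# mirror the attributes read in DeepLabV3Plus.__init__ above; the concrete
# values (input length, ASPP rates, channel width, middle_repeat) are
# assumptions for illustration only.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        output_stride=16,
        kernel_size=3,
        middle_block_rate=1,
        exit_block_rates=(1, 2),
        middle_repeat=2,  # paper/backbone default is 16; small for a quick test
        interpolate_mode="linear",
        aspp_channel=256,
        aspp_rate=(1, 6, 12, 18),
        output_size=3,  # p, qrs, t
    )
    model = DeepLabV3Plus(config).eval()
    with torch.no_grad():
        out = model(torch.randn(2, 1, 2048))  # (batch, channels, length)
    print(out.shape)  # expected: torch.Size([2, 3, 2048])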