"""1D PSPNet (Pyramid Scene Parsing Network).

paper: https://arxiv.org/abs/1612.01105
ref:
- https://github.com/hszhao/PSPNet
"""
import torch
from torch import nn

# BUGFIX: the original `from torch.functional import F` depended on a
# non-public re-export inside torch/functional.py; torch.nn.functional is
# the documented location of `interpolate` etc.
from torch.nn import functional as F


class PPM(nn.Module):
    """Pyramid Pooling Module (1D).

    Pools the input at several scales (``bins``), projects each pooled map to
    ``reduction_dim`` channels with a 1x1 conv + BN + ReLU, upsamples each
    branch back to the input length, and concatenates all branches with the
    input along the channel axis.

    Output channels = ``in_dim + len(bins) * reduction_dim``.
    """

    def __init__(self, in_dim, reduction_dim, bins, interplate_mode):
        # NOTE: `interplate_mode` (sic) is a typo for "interpolate_mode";
        # the name is kept so existing keyword callers don't break.
        super(PPM, self).__init__()
        branches = []
        for bin_size in bins:  # `bin_size`, not `bin`: avoid shadowing the builtin
            branches.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool1d(bin_size),
                    nn.Conv1d(in_dim, reduction_dim, kernel_size=1, bias=False),
                    nn.BatchNorm1d(reduction_dim),
                    nn.ReLU(),
                )
            )
        self.features = nn.ModuleList(branches)
        self.interplate_mode = interplate_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` concatenated with all upsampled pooling branches.

        Args:
            x: tensor of shape (batch, in_dim, length).

        Returns:
            Tensor of shape (batch, in_dim + len(bins) * reduction_dim, length).
        """
        length = x.size(2)
        out = [x]
        for branch in self.features:
            # Upsample each pooled branch back to the input length before concat.
            out.append(F.interpolate(branch(x), length, mode=self.interplate_mode))
        return torch.cat(out, dim=1)


class Bottleneck(nn.Module):
    """ResNet-style 1D bottleneck block.

    1x1 conv (reduce) -> kxk conv (optionally strided/dilated) ->
    1x1 conv (expand to ``planes * expansion``), with a residual connection.
    ``downsample`` (if given) projects the residual to the matching
    shape/channel count.
    """

    def __init__(
        self,
        inplanes,
        planes,
        expansion=4,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=1,
        downsample=None,
    ):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, planes * expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the residual so shapes/channels match the main path.
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class PSPNet(nn.Module):
    """1D PSPNet: bottleneck encoder + pyramid pooling + segmentation head.

    Expects ``config`` to provide: kernel_size, expansion, inplanes,
    num_layers, num_bottlenecks, interpolate_mode, dilation, ppm_bins,
    aux_idx, aux_ratio, dropout, output_size.
    """

    def __init__(self, config):
        super(PSPNet, self).__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        self.padding = (self.kernel_size - 1) // 2
        self.expansion = int(config.expansion)
        self.inplanes = int(config.inplanes)
        num_layers = int(config.num_layers)
        self.num_bottlenecks = int(config.num_bottlenecks)
        self.interpolate_mode = str(config.interpolate_mode)
        self.dilation = int(config.dilation)
        ppm_bins: list = config.ppm_bins
        self.aux_idx = int(config.aux_idx)
        # Raise (not assert) so the check survives `python -O`.
        if not 0 <= self.aux_idx < num_layers:
            raise ValueError(
                f"aux_idx ({self.aux_idx}) must be in [0, num_layers={num_layers})"
            )
        self.aux_ratio = float(config.aux_ratio)
        dropout = float(config.dropout)
        output_size = config.output_size  # number of classes, e.g. 3 (p, qrs, t)

        # Stem: the input is downsampled by 1/4 before the encoder layers
        # (stride-2 conv followed by stride-2 max-pool).
        self.stem = nn.Sequential(
            nn.Conv1d(
                1,
                self.inplanes,
                self.kernel_size,
                stride=2,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(self.inplanes),
            nn.ReLU(),
            nn.MaxPool1d(self.kernel_size, stride=2, padding=self.padding),
        )

        layers = []
        plane = self.inplanes
        for i in range(num_layers):
            # Each layer doubles the base width and halves the length;
            # _make_layer mutates self.inplanes as it goes.
            layers.append(self._make_layer(plane * (2**i)))
        self.layers = nn.ModuleList(layers)

        encode_dim = self.inplanes  # channel count after the last encoder layer
        reduction_dim = encode_dim // len(ppm_bins)
        self.ppm = PPM(
            encode_dim,
            reduction_dim,
            ppm_bins,
            self.interpolate_mode,
        )
        # BUGFIX: the PPM concat width is encode_dim + len(bins)*reduction_dim,
        # which equals 2*encode_dim only when encode_dim is divisible by
        # len(ppm_bins). Compute it exactly to avoid a channel mismatch in
        # the non-divisible case (identical value in the divisible case).
        encode_dim = encode_dim + reduction_dim * len(ppm_bins)

        self.cls = nn.Sequential(
            nn.Conv1d(
                encode_dim,
                512,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout1d(dropout),
            nn.Conv1d(512, output_size, kernel_size=1),
        )

        self.aux_branch = nn.Sequential(
            # In-channels must match the output width of encoder layer
            # `aux_idx`, i.e. plane * expansion * 2**aux_idx.
            nn.Conv1d(
                plane * self.expansion * (2**self.aux_idx),
                256,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout1d(0.1),
            nn.Conv1d(256, output_size, kernel_size=1),
        )

    def _make_layer(self, planes: int):
        """Build one encoder layer of ``self.num_bottlenecks`` bottlenecks.

        The first bottleneck downsamples by stride 2 (with a matching 1x1
        projection on the residual); the remaining bottlenecks use dilated
        convolutions (``self.dilation``) with "same" padding.

        Side effect: updates ``self.inplanes`` to ``planes * self.expansion``.
        """
        downsample = nn.Sequential(
            nn.Conv1d(
                self.inplanes,
                planes * self.expansion,
                kernel_size=1,
                stride=2,
                bias=False,
            ),
            nn.BatchNorm1d(planes * self.expansion),
        )
        bottlenecks = [
            Bottleneck(
                self.inplanes,
                planes,
                expansion=self.expansion,
                kernel_size=self.kernel_size,
                stride=2,
                dilation=1,
                padding=self.padding,
                downsample=downsample,
            )
        ]
        self.inplanes = planes * self.expansion
        for _ in range(1, self.num_bottlenecks):
            bottlenecks.append(
                Bottleneck(
                    self.inplanes,
                    planes,
                    expansion=self.expansion,
                    kernel_size=self.kernel_size,
                    stride=1,
                    dilation=self.dilation,
                    # "same" padding for the dilated conv.
                    padding=(self.dilation * (self.kernel_size - 1)) // 2,
                )
            )
        return nn.Sequential(*bottlenecks)

    def forward(self, input: torch.Tensor, y=None):
        """Run the network on ``input`` of shape (batch, 1, length).

        In training mode, returns the weighted blend of the main head and the
        auxiliary head (weight ``aux_ratio``); in eval mode, only the main
        head. ``y`` is accepted for training-loop compatibility but unused.
        """
        output: torch.Tensor = input
        output = self.stem(output)
        aux = None  # set inside the loop; aux_idx < num_layers is validated in __init__
        for i, _layer in enumerate(self.layers):
            output = _layer(output)
            if i == self.aux_idx:
                aux = output
        output = self.ppm(output)
        output = self.cls(output)
        # Upsample logits back to the original input length.
        output = F.interpolate(
            output,
            input.shape[2],
            mode=self.interpolate_mode,
        )
        if self.training:
            aux = self.aux_branch(aux)
            aux = F.interpolate(
                aux,
                input.shape[2],
                mode=self.interpolate_mode,
            )
            return torch.add(output * (1 - self.aux_ratio), aux * self.aux_ratio)
        else:
            return output