Spaces:
Sleeping
Sleeping
""" | |
paper: https://arxiv.org/abs/1612.01105 | |
ref: | |
- https://github.com/hszhao/PSPNet | |
""" | |
import torch | |
from torch import nn | |
from torch.functional import F | |
class PPM(nn.Module):
    """Pyramid Pooling Module for 1-D feature maps.

    One branch is built per entry in ``bins``. Each branch adaptively
    average-pools the input to that many positions, projects it to
    ``reduction_dim`` channels with a 1x1 convolution, then applies batch
    normalisation and ReLU. ``forward`` resizes every branch back to the
    input length and concatenates it with the input along the channel axis.

    Args:
        in_dim: Number of channels of the incoming feature map.
        reduction_dim: Number of channels produced by each pooling branch.
        bins: Iterable of pooled output lengths, one per branch.
        interplate_mode: ``mode`` argument forwarded to ``F.interpolate``
            (parameter name kept as spelled for caller compatibility).
    """

    def __init__(self, in_dim, reduction_dim, bins, interplate_mode):
        super().__init__()
        self.features = nn.ModuleList(
            [
                self._make_branch(in_dim, reduction_dim, pooled_len)
                for pooled_len in bins
            ]
        )
        self.interplate_mode = interplate_mode

    @staticmethod
    def _make_branch(in_dim, reduction_dim, pooled_len):
        """Build one pool -> 1x1 conv -> BN -> ReLU pyramid branch."""
        return nn.Sequential(
            nn.AdaptiveAvgPool1d(pooled_len),
            nn.Conv1d(in_dim, reduction_dim, kernel_size=1, bias=False),
            nn.BatchNorm1d(reduction_dim),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor):
        """Return ``x`` concatenated with every upsampled pyramid branch.

        The result has ``in_dim + len(bins) * reduction_dim`` channels and
        the same length (dim 2) as ``x``.
        """
        target_len = x.size()[2]
        pyramid = [
            F.interpolate(branch(x), target_len, mode=self.interplate_mode)
            for branch in self.features
        ]
        return torch.cat([x, *pyramid], dim=1)
class Bottleneck(nn.Module):
    """Residual bottleneck block built from 1-D convolutions.

    The main path is 1x1 conv -> BN -> ReLU, then a ``kernel_size`` conv
    (carrying the stride, dilation and padding) -> BN -> ReLU, then a 1x1
    conv expanding to ``planes * expansion`` channels -> BN. The shortcut
    is added before the final ReLU.

    Args:
        inplanes: Channels of the input tensor.
        planes: Channels of the two inner convolutions.
        expansion: Multiplier applied to ``planes`` for the output width.
        kernel_size: Kernel size of the middle convolution.
        stride: Stride of the middle convolution.
        dilation: Dilation of the middle convolution.
        padding: Padding of the middle convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut, so
            it must already match the main path's output shape.
    """

    def __init__(
        self,
        inplanes,
        planes,
        expansion=4,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=1,
        downsample=None,
    ):
        super().__init__()
        out_planes = planes * expansion

        # Channel-reducing 1x1 convolution.
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)

        # Convolution that carries stride / dilation / padding.
        self.conv2 = nn.Conv1d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)

        # Channel-expanding 1x1 convolution.
        self.conv3 = nn.Conv1d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(out_planes)

        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        """Apply the bottleneck and return ``relu(main_path(x) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class PSPNet(nn.Module):
    """PSPNet-style segmentation network built from 1-D convolutions.

    Pipeline: stem (stride-2 conv + stride-2 max-pool) -> ``num_layers``
    residual stages (each halves the length) -> pyramid pooling module ->
    classifier head -> interpolation back to the input length. An auxiliary
    head taps the output of stage ``aux_idx`` and is only used in training
    mode.

    ``config`` must expose the attributes read below: ``kernel_size``,
    ``expansion``, ``inplanes``, ``num_layers``, ``num_bottlenecks``,
    ``interpolate_mode``, ``dilation``, ``ppm_bins``, ``aux_idx``,
    ``aux_ratio``, ``dropout`` and ``output_size``.

    NOTE: the padding formulas ``(k - 1) // 2`` only keep the main path and
    the shortcut the same length for odd ``kernel_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        self.padding = (self.kernel_size - 1) // 2
        self.expansion = int(config.expansion)
        # NOTE: self.inplanes is advanced by every _make_layer call, so after
        # construction it holds the channel count of the last stage.
        self.inplanes = int(config.inplanes)
        num_layers = int(config.num_layers)
        self.num_bottlenecks = int(config.num_bottlenecks)
        self.interpolate_mode = str(config.interpolate_mode)
        self.dilation = int(config.dilation)
        ppm_bins = config.ppm_bins
        self.aux_idx = int(config.aux_idx)
        # A negative index would yield a fractional channel count below
        # (2 ** aux_idx) and never match a stage in forward(). Kept as an
        # assert so callers see the same exception type as before.
        assert 0 <= self.aux_idx < num_layers, (
            f"aux_idx must be in [0, {num_layers}), got {self.aux_idx}"
        )
        self.aux_ratio = float(config.aux_ratio)
        dropout = float(config.dropout)
        output_size = config.output_size  # 3 (p, qrs, t)

        # The stem downsamples the signal to 1/4 of its length before the
        # residual stages start.
        self.stem = nn.Sequential(
            nn.Conv1d(
                1,
                self.inplanes,
                self.kernel_size,
                stride=2,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(self.inplanes),
            nn.ReLU(),
            nn.MaxPool1d(self.kernel_size, stride=2, padding=self.padding),
        )

        # Stage i uses base_planes * 2**i inner channels.
        base_planes = self.inplanes
        self.layers = nn.ModuleList(
            [self._make_layer(base_planes * (2**i)) for i in range(num_layers)]
        )

        encode_dim = self.inplanes
        reduction_dim = encode_dim // len(ppm_bins)
        self.ppm = PPM(
            encode_dim,
            reduction_dim,
            ppm_bins,
            self.interpolate_mode,
        )
        # The PPM concatenates its input with one reduction_dim-wide branch
        # per bin. Deriving the width from that (instead of assuming
        # encode_dim * 2) stays correct when encode_dim is not divisible by
        # the number of bins.
        cls_in_dim = encode_dim + reduction_dim * len(ppm_bins)

        self.cls = nn.Sequential(
            nn.Conv1d(
                cls_in_dim,
                512,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout1d(dropout),
            nn.Conv1d(512, output_size, kernel_size=1),
        )

        self.aux_branch = nn.Sequential(
            # Input channels must match the output width of the stage
            # selected by aux_idx.
            nn.Conv1d(
                base_planes * self.expansion * (2**self.aux_idx),
                256,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout1d(0.1),
            nn.Conv1d(256, output_size, kernel_size=1),
        )

    def _make_layer(self, planes: int):
        """Return one stage made of ``self.num_bottlenecks`` bottlenecks.

        The first bottleneck downsamples by 2 (stride-2 conv plus a stride-2
        1x1 projection on the shortcut). Every following bottleneck keeps
        the length and uses a dilated convolution with ``self.dilation``.

        Side effect: sets ``self.inplanes`` to ``planes * self.expansion`` so
        that the next stage chains onto this one.
        """
        out_planes = planes * self.expansion
        downsample = nn.Sequential(
            nn.Conv1d(
                self.inplanes,
                out_planes,
                kernel_size=1,
                stride=2,
                bias=False,
            ),
            nn.BatchNorm1d(out_planes),
        )
        bottlenecks = [
            Bottleneck(
                self.inplanes,
                planes,
                expansion=self.expansion,
                kernel_size=self.kernel_size,
                stride=2,
                dilation=1,
                padding=self.padding,
                downsample=downsample,
            )
        ]
        self.inplanes = out_planes
        # Padding that keeps the length unchanged under dilation.
        dilated_padding = (self.dilation * (self.kernel_size - 1)) // 2
        for _ in range(1, self.num_bottlenecks):
            bottlenecks.append(
                Bottleneck(
                    self.inplanes,
                    planes,
                    expansion=self.expansion,
                    kernel_size=self.kernel_size,
                    stride=1,
                    dilation=self.dilation,
                    padding=dilated_padding,
                )
            )
        return nn.Sequential(*bottlenecks)

    def forward(self, input: torch.Tensor, y=None):
        """Run the network on ``input`` of shape ``(batch, 1, length)``.

        Args:
            input: Signal tensor; the stem expects a single input channel.
            y: Unused; accepted for caller compatibility.

        Returns:
            A ``(batch, output_size, length)`` tensor. In eval mode this is
            the main head's output. In training mode it is the blend
            ``(1 - aux_ratio) * main + aux_ratio * aux``.
        """
        target_len = input.shape[2]
        output: torch.Tensor = self.stem(input)

        aux = None
        for i, stage in enumerate(self.layers):
            output = stage(output)
            if i == self.aux_idx:
                aux = output

        output = self.ppm(output)
        output = self.cls(output)
        output = F.interpolate(output, target_len, mode=self.interpolate_mode)

        if not self.training:
            return output

        aux = self.aux_branch(aux)
        aux = F.interpolate(aux, target_len, mode=self.interpolate_mode)
        return torch.add(output * (1 - self.aux_ratio), aux * self.aux_ratio)