wogh2012's picture
refactor: add implementations
aefacda
"""
paper: https://arxiv.org/abs/1612.01105
ref:
- https://github.com/hszhao/PSPNet
"""
import torch
from torch import nn
from torch.functional import F
class PPM(nn.Module):
    """Pyramid Pooling Module (1-D).

    For each entry of ``bins`` the input is adaptively average-pooled to that
    many positions, projected to ``reduction_dim`` channels by a 1x1 conv +
    BatchNorm + ReLU, and interpolated back to the input length. The input and
    all pyramid branches are concatenated along dim 1, so the output has
    ``in_dim + len(bins) * reduction_dim`` channels.

    Args:
        in_dim: number of input channels.
        reduction_dim: channels produced by each pyramid branch.
        bins: output sizes passed to ``nn.AdaptiveAvgPool1d``, one per branch.
        interplate_mode: ``mode`` forwarded to ``F.interpolate`` (spelling of
            the parameter/attribute kept as-is for compatibility).
    """

    def __init__(self, in_dim, reduction_dim, bins, interplate_mode):
        super(PPM, self).__init__()
        branches = [
            nn.Sequential(
                nn.AdaptiveAvgPool1d(size),
                nn.Conv1d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm1d(reduction_dim),
                nn.ReLU(),
            )
            for size in bins
        ]
        self.features = nn.ModuleList(branches)
        self.interplate_mode = interplate_mode

    def forward(self, x: torch.Tensor):
        """Concatenate ``x`` with every pyramid branch resized to ``x``'s length.

        The length is read from ``x.size(2)``, so ``x`` must have at least
        three dims (presumably (batch, channels, length)).
        """
        target_len = x.size(2)
        pyramid = [
            F.interpolate(branch(x), target_len, mode=self.interplate_mode)
            for branch in self.features
        ]
        return torch.cat([x, *pyramid], dim=1)
class Bottleneck(nn.Module):
    """1-D ResNet-style bottleneck block with a residual shortcut.

    Main path: 1x1 conv (inplanes -> planes) -> BN -> ReLU ->
    ``kernel_size`` conv (planes -> planes, with ``stride``/``dilation``/
    ``padding``) -> BN -> ReLU -> 1x1 conv (planes -> planes * expansion) -> BN.
    The shortcut is ``x`` itself, or ``downsample(x)`` when a ``downsample``
    module is given. The sum of both paths is passed through ReLU.

    Args:
        inplanes: input channels.
        planes: width of the inner convolutions.
        expansion: output channels are ``planes * expansion``.
        kernel_size, stride, dilation, padding: settings of the middle conv.
        downsample: optional module applied to ``x`` on the shortcut; it must
            produce a tensor shaped like the main path's output.
    """

    def __init__(
        self,
        inplanes,
        planes,
        expansion=4,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=1,
        downsample=None,
    ):
        super(Bottleneck, self).__init__()
        widened = planes * expansion
        # Registration order is kept (conv1, bn1, conv2, bn2, conv3, bn3,
        # relu, downsample) so state_dict keys and parameter order are stable.
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(widened)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return self.relu(hidden + shortcut)
class PSPNet(nn.Module):
    """1-D PSPNet: ResNet-style bottleneck encoder + Pyramid Pooling Module.

    The input goes through a stem (stride-2 conv + stride-2 max-pool, i.e. a
    4x length reduction), ``num_layers`` bottleneck stages (each halving the
    length in its first bottleneck), a ``PPM`` and a classifier head. The
    logits are interpolated back to the input length. In training mode an
    auxiliary head on the output of stage ``aux_idx`` is blended into the
    returned logits with weight ``aux_ratio``.

    Args:
        config: object exposing the attributes ``kernel_size``, ``expansion``,
            ``inplanes``, ``num_layers``, ``num_bottlenecks``,
            ``interpolate_mode``, ``dilation``, ``ppm_bins``, ``aux_idx``,
            ``aux_ratio``, ``dropout`` and ``output_size``.

    Raises:
        ValueError: if ``aux_idx`` is not in ``[0, num_layers)`` or
            ``ppm_bins`` is empty.
    """

    def __init__(self, config):
        super(PSPNet, self).__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        # Keeps the length unchanged for stride-1 convs with an odd kernel size.
        self.padding = (self.kernel_size - 1) // 2
        self.expansion = int(config.expansion)
        self.inplanes = int(config.inplanes)
        num_layers = int(config.num_layers)
        self.num_bottlenecks = int(config.num_bottlenecks)
        self.interpolate_mode = str(config.interpolate_mode)
        self.dilation = int(config.dilation)
        ppm_bins: list = config.ppm_bins
        self.aux_idx = int(config.aux_idx)
        # forward() only captures the auxiliary feature map when a stage index
        # equals aux_idx; an out-of-range value would leave it unset.
        if not 0 <= self.aux_idx < num_layers:
            raise ValueError(
                f"aux_idx must be in [0, {num_layers}), got {self.aux_idx}"
            )
        if len(ppm_bins) == 0:
            raise ValueError("ppm_bins must contain at least one bin size")
        self.aux_ratio = float(config.aux_ratio)
        dropout = float(config.dropout)
        output_size = config.output_size  # e.g. 3 classes (p, qrs, t)

        # The stem downsamples the input length by 4 (stride-2 conv, then
        # stride-2 max-pool) before the bottleneck stages.
        self.stem = nn.Sequential(
            nn.Conv1d(
                1,
                self.inplanes,
                self.kernel_size,
                stride=2,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(self.inplanes),
            nn.ReLU(),
            nn.MaxPool1d(self.kernel_size, stride=2, padding=self.padding),
        )

        # Stage i uses planes = base * 2**i. _make_layer advances
        # self.inplanes, so the stages must be built in order.
        plane = self.inplanes
        self.layers = nn.ModuleList(
            [self._make_layer(plane * (2**i)) for i in range(num_layers)]
        )

        # Channel count coming out of the last stage.
        encode_dim = self.inplanes
        reduction_dim = encode_dim // len(ppm_bins)
        self.ppm = PPM(
            encode_dim,
            reduction_dim,
            ppm_bins,
            self.interpolate_mode,
        )
        # PPM concatenates its input with one reduction_dim-wide map per bin.
        # This equals 2 * encode_dim only when encode_dim divides evenly by
        # len(ppm_bins), so compute it exactly instead of assuming that.
        encode_dim += reduction_dim * len(ppm_bins)

        self.cls = nn.Sequential(
            nn.Conv1d(
                encode_dim,
                512,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout1d(dropout),
            nn.Conv1d(512, output_size, kernel_size=1),
        )
        self.aux_branch = nn.Sequential(
            # Input channels must match the output of stage aux_idx:
            # (base * 2**aux_idx) * expansion.
            nn.Conv1d(
                plane * self.expansion * (2**self.aux_idx),
                256,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout1d(0.1),
            nn.Conv1d(256, output_size, kernel_size=1),
        )

    def _make_layer(self, planes: int):
        """Build one encoder stage made of ``self.num_bottlenecks`` bottlenecks.

        The first bottleneck downsamples by 2 (stride-2 middle conv and a
        strided 1x1 projection shortcut). The remaining bottlenecks keep the
        length and use dilated convolutions with ``self.dilation``.

        Side effect: sets ``self.inplanes`` to ``planes * self.expansion``
        (this stage's output channels) so the next stage chains onto it.
        """
        downsample = nn.Sequential(
            nn.Conv1d(
                self.inplanes,
                planes * self.expansion,
                kernel_size=1,
                stride=2,
                bias=False,
            ),
            nn.BatchNorm1d(planes * self.expansion),
        )
        bottlenecks = [
            Bottleneck(
                self.inplanes,
                planes,
                expansion=self.expansion,
                kernel_size=self.kernel_size,
                stride=2,
                dilation=1,
                padding=self.padding,
                downsample=downsample,
            )
        ]
        self.inplanes = planes * self.expansion
        # "Same" padding for a dilated conv with an odd kernel size.
        dilated_padding = (self.dilation * (self.kernel_size - 1)) // 2
        for _ in range(1, self.num_bottlenecks):
            bottlenecks.append(
                Bottleneck(
                    self.inplanes,
                    planes,
                    expansion=self.expansion,
                    kernel_size=self.kernel_size,
                    stride=1,
                    dilation=self.dilation,
                    padding=dilated_padding,
                )
            )
        return nn.Sequential(*bottlenecks)

    def forward(self, input: torch.Tensor, y=None):
        """Return logits of shape (batch, output_size, input.shape[2]).

        Args:
            input: tensor fed to the stem's single-input-channel ``Conv1d``
                (presumably (batch, 1, length)). Its length is read from
                ``input.shape[2]``.
            y: unused here (presumably kept so callers may pass targets).

        In training mode the returned tensor is
        ``(1 - aux_ratio) * main_logits + aux_ratio * aux_logits``. In eval mode
        only the main logits are returned.
        """
        length = input.shape[2]
        aux = None
        output = self.stem(input)
        for i, layer in enumerate(self.layers):
            output = layer(output)
            if i == self.aux_idx:
                aux = output
        output = self.ppm(output)
        output = self.cls(output)
        output = F.interpolate(output, length, mode=self.interpolate_mode)
        if not self.training:
            return output
        aux = self.aux_branch(aux)
        aux = F.interpolate(aux, length, mode=self.interpolate_mode)
        return torch.add(output * (1 - self.aux_ratio), aux * self.aux_ratio)