ghlee94's picture
Init
2a13495
import torch.nn as nn
from segmentation_models_pytorch.base import modules
class TransposeX2(nn.Sequential):
    """Transposed-convolution upsampling stage: conv-transpose -> [batch norm] -> ReLU.

    The transposed convolution uses ``kernel_size=4, stride=2, padding=1``,
    which maps a spatial size ``H x W`` to ``2H x 2W``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the transposed convolution.
        use_batchnorm: When true, a ``BatchNorm2d`` layer is placed between the
            transposed convolution and the ReLU.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        # Assemble the layers in their final order and initialise the
        # container exactly once; sub-module indices (and therefore
        # state-dict keys) are: 0 = conv-transpose, 1 = batch norm (if any),
        # last = ReLU.
        layers = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1
            ),
        ]
        if use_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)
class DecoderBlock(nn.Module):
    """Single decoder stage: 1x1 reduce -> ``TransposeX2`` upsample -> 1x1 project.

    The intermediate width is ``in_channels // 4``; the final 1x1 stage maps
    that width to ``out_channels``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the feature map this stage returns.
        use_batchnorm: Forwarded unchanged to every sub-layer of the stage.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()

        # Width shared by the reduce stage output and the upsampling stage.
        bottleneck = in_channels // 4

        reduce = modules.Conv2dReLU(
            in_channels,
            bottleneck,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )
        upsample = TransposeX2(bottleneck, bottleneck, use_batchnorm=use_batchnorm)
        project = modules.Conv2dReLU(
            bottleneck,
            out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )

        self.block = nn.Sequential(reduce, upsample, project)

    def forward(self, x, skip=None):
        """Run the stage on ``x`` and, if ``skip`` is given, add it to the result.

        The addition is element-wise, so ``skip`` must be compatible with the
        stage output under ``+``.
        """
        out = self.block(x)
        if skip is None:
            return out
        return out + skip
class LinknetDecoder(nn.Module):
    """Chain of ``DecoderBlock`` stages consuming encoder features deepest-first.

    Args:
        encoder_channels: Channel counts of the encoder outputs. The first
            entry is discarded and the remainder is traversed in reverse.
        prefinal_channels: Output width of the last decoder stage.
        n_blocks: Number of decoder stages to build.
        use_batchnorm: Forwarded unchanged to every ``DecoderBlock``.
    """

    def __init__(
        self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True,
    ):
        super().__init__()

        # Discard the first encoder entry, then order from the encoder head
        # (last entry) backwards; the pre-final width closes the sequence.
        head_first = encoder_channels[1:][::-1]
        widths = list(head_first) + [prefinal_channels]

        # Stage k maps widths[k] -> widths[k + 1].
        stages = []
        for idx in range(n_blocks):
            stages.append(
                DecoderBlock(widths[idx], widths[idx + 1], use_batchnorm=use_batchnorm)
            )
        self.blocks = nn.ModuleList(stages)

    def forward(self, *features):
        """Decode the encoder ``features`` and return the last stage's output.

        The first feature is dropped and the rest are reversed. The first
        remaining feature is the input to the first stage; the following ones
        are passed, in order, as the ``skip`` argument of successive stages.
        Stages beyond the available skips receive ``None``.
        """
        ordered = features[1:][::-1]
        out = ordered[0]
        laterals = ordered[1:]
        available = len(laterals)

        for idx, stage in enumerate(self.blocks):
            lateral = laterals[idx] if idx < available else None
            out = stage(out, lateral)

        return out