|
import torch.nn as nn |
|
|
|
from segmentation_models_pytorch.base import modules |
|
|
|
|
|
class TransposeX2(nn.Sequential):
    """Upsample a feature map by a factor of two with a transposed convolution.

    Layers, in order:

    1. ``ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)``
    2. ``BatchNorm2d(out_channels)`` -- only when ``use_batchnorm`` is true
    3. ``ReLU(inplace=True)``

    With kernel 4, stride 2 and padding 1 the output spatial size is
    ``(H - 1) * 2 - 2 + 4 = 2 * H`` (likewise for W), i.e. exactly doubled.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the transposed convolution.
        use_batchnorm: if true, insert ``BatchNorm2d`` between the convolution
            and the ReLU.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        layers = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.ReLU(inplace=True),
        ]
        if use_batchnorm:
            # Place BatchNorm between the conv and the activation (index 1) so
            # Sequential indices / state_dict keys match the expected layout.
            layers.insert(1, nn.BatchNorm2d(out_channels))
        # One initialisation suffices: the layers above are plain locals, not
        # attributes, so no Module state is needed before this call.
        super().__init__(*layers)
|
|
|
|
|
class DecoderBlock(nn.Module):
    """One LinkNet decoder stage: 1x1 reduce, 2x upsample, 1x1 expand.

    The input is projected to ``in_channels // 4`` channels by a 1x1
    ``modules.Conv2dReLU``. It is then upsampled by :class:`TransposeX2`
    (which doubles H and W) and finally projected to ``out_channels`` by a
    second 1x1 ``modules.Conv2dReLU``. ``use_batchnorm`` is forwarded to all
    three sub-modules.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()
        bottleneck = in_channels // 4

        # Stages are built in the same order they run; their position in the
        # Sequential (block.0 / block.1 / block.2) defines the state_dict keys.
        reduce_stage = modules.Conv2dReLU(
            in_channels,
            bottleneck,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )
        upsample_stage = TransposeX2(
            bottleneck, bottleneck, use_batchnorm=use_batchnorm
        )
        expand_stage = modules.Conv2dReLU(
            bottleneck,
            out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )
        self.block = nn.Sequential(reduce_stage, upsample_stage, expand_stage)

    def forward(self, x, skip=None):
        """Run the stage; if ``skip`` is given, add it element-wise to the output.

        ``skip`` must be broadcast-compatible with the block output (presumably
        the same shape -- confirm against the encoder feature sizes).
        """
        decoded = self.block(x)
        return decoded if skip is None else decoded + skip
|
|
|
|
|
class LinknetDecoder(nn.Module):
    """Stack of :class:`DecoderBlock` stages for a LinkNet-style decoder.

    The first entry of ``encoder_channels`` is discarded and the rest are
    reversed (presumably so the deepest encoder stage comes first -- confirm
    against the encoder). ``prefinal_channels`` is appended to that list, and
    block ``i`` maps ``channels[i] -> channels[i + 1]`` for
    ``i in range(n_blocks)``.

    Args:
        encoder_channels: per-stage channel counts reported by the encoder.
        prefinal_channels: output channels of the last stage when
            ``n_blocks == len(encoder_channels) - 1``.
        n_blocks: number of decoder stages. A value larger than
            ``len(encoder_channels) - 1`` raises ``IndexError``.
        use_batchnorm: forwarded to every :class:`DecoderBlock`.
    """

    def __init__(
        self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True,
    ):
        super().__init__()

        # Skip the first encoder entry, then walk the remaining ones in reverse.
        ordered_channels = list(encoder_channels[1:][::-1])
        ordered_channels.append(prefinal_channels)

        stages = []
        for idx in range(n_blocks):
            stages.append(
                DecoderBlock(
                    ordered_channels[idx],
                    ordered_channels[idx + 1],
                    use_batchnorm=use_batchnorm,
                )
            )
        self.blocks = nn.ModuleList(stages)

    def forward(self, *features):
        """Decode encoder ``features``, mirroring the ordering used in ``__init__``.

        ``features[0]`` is ignored and the rest are reversed. The first of the
        reversed maps is the decoder input, and the following ones are fed as
        skips to successive blocks. Any block past the last available skip
        receives ``None``.
        """
        ordered = features[1:][::-1]
        x = ordered[0]
        skips = ordered[1:]
        skip_count = len(skips)

        for idx, stage in enumerate(self.blocks):
            skip = None if idx >= skip_count else skips[idx]
            x = stage(x, skip)

        return x
|
|