Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
from feature_extractor_models.base import modules | |
class TransposeX2(nn.Sequential):
    """Upsample a feature map by 2x with a transposed convolution.

    Layer order is ``ConvTranspose2d -> [BatchNorm2d] -> ReLU``. With
    ``kernel_size=4, stride=2, padding=1`` the output spatial size is exactly
    twice the input's (``H_out = 2 * H_in``, ``W_out = 2 * W_in``).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the transposed convolution.
        use_batchnorm: If truthy, insert ``BatchNorm2d(out_channels)`` between
            the transposed convolution and the ReLU.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        # NOTE: the conv keeps PyTorch's default bias even when BatchNorm
        # follows; dropping it would change the module's state_dict keys.
        layers = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.ReLU(inplace=True),
        ]
        if use_batchnorm:
            # Insert at index 1 so the layout stays conv=0, bn=1, relu=2.
            layers.insert(1, nn.BatchNorm2d(out_channels))
        # Initialise nn.Sequential exactly once, with the final layer list
        # (the original called super().__init__() twice).
        super().__init__(*layers)
class DecoderBlock(nn.Module):
    """One LinkNet decoder stage: 1x1 narrow, 2x upsample, 1x1 project.

    The bottleneck width is ``in_channels // 4``. The middle stage is a
    ``TransposeX2``, which doubles the spatial size. If ``skip`` is given,
    it is added element-wise to the stage output. The two
    ``modules.Conv2dReLU`` layers are defined outside this file; they are
    presumably conv + optional BN + ReLU (TODO confirm).
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()
        bottleneck = in_channels // 4
        narrow = modules.Conv2dReLU(
            in_channels, bottleneck, kernel_size=1, use_batchnorm=use_batchnorm
        )
        upsample = TransposeX2(bottleneck, bottleneck, use_batchnorm=use_batchnorm)
        project = modules.Conv2dReLU(
            bottleneck, out_channels, kernel_size=1, use_batchnorm=use_batchnorm
        )
        # Order (and therefore submodule indices 0/1/2) matches the original layout.
        self.block = nn.Sequential(narrow, upsample, project)

    def forward(self, x, skip=None):
        # skip is assumed to match the block output's shape — TODO confirm at call sites.
        out = self.block(x)
        return out if skip is None else out + skip
class LinknetDecoder(nn.Module):
    """LinkNet decoder: a chain of ``DecoderBlock`` upsampling stages.

    Encoder features are consumed starting from the encoder head (the last
    entry). Block ``i`` receives the running decoder tensor and, as a skip
    connection, the next encoder feature in that reversed order. Blocks past
    the available skips receive ``skip=None``.

    Args:
        encoder_channels: Channel counts of the encoder outputs, in the same
            order as the features later passed to ``forward``. The first
            entry is dropped, mirroring ``forward``.
        prefinal_channels: Output channels of the last block. This applies
            only when ``n_blocks == len(encoder_channels) - 1``. With fewer
            blocks the last block outputs an encoder channel count instead,
            and this value is unused.
        n_blocks: Number of decoder blocks; at most
            ``len(encoder_channels) - 1``.
        use_batchnorm: Forwarded to every ``DecoderBlock``.

    Raises:
        ValueError: If ``n_blocks`` exceeds ``len(encoder_channels) - 1``.
    """

    def __init__(
        self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True
    ):
        super().__init__()
        # Drop the first entry, then reverse so index 0 is the encoder head.
        encoder_channels = encoder_channels[1:][::-1]
        channels = list(encoder_channels) + [prefinal_channels]

        # Block i maps channels[i] -> channels[i + 1]. Fail clearly instead of
        # with an IndexError deep inside the comprehension.
        max_blocks = len(channels) - 1
        if n_blocks > max_blocks:
            raise ValueError(
                f"n_blocks={n_blocks} exceeds the {max_blocks} decoder stages "
                f"available from encoder_channels (len(encoder_channels) - 1)"
            )

        self.blocks = nn.ModuleList(
            [
                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
                for i in range(n_blocks)
            ]
        )

    def forward(self, *features):
        """Decode encoder ``features`` (given in ``encoder_channels`` order)."""
        # Same reordering as __init__: drop the first feature, head first.
        features = features[1:][::-1]
        x = features[0]
        skips = features[1:]
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
        return x