import torch
import torch.nn as nn
import torch.nn.functional as F

from ...base import modules as md


class DecoderBlock(nn.Module):
    """One decoder stage: upsample, optionally merge a skip tensor, then two convs.

    Args:
        in_channels: Channel count of the tensor coming from the previous stage.
        skip_channels: Channel count of the skip tensor concatenated after
            upsampling (0 when the stage has no skip connection).
        out_channels: Channel count produced by both convolutions.
        use_batchnorm: Forwarded unchanged to ``md.Conv3dReLU``.
        attention_type: Two-item sequence; item 0 configures the attention
            applied to the concatenated tensor, item 1 the attention applied
            to the block output. Each item is forwarded to ``md.Attention``.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=(None, None),  # immutable default instead of a shared list
    ):
        super().__init__()
        merged_channels = in_channels + skip_channels
        self.conv1 = md.Conv3dReLU(
            merged_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention1 = md.Attention(attention_type[0], in_channels=merged_channels)
        self.conv2 = md.Conv3dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = md.Attention(attention_type[1], in_channels=out_channels)

    def forward(self, x, skip=None):
        """Upsample ``x`` by 2, fuse ``skip`` if given, and run both convolutions.

        ``attention1`` is only applied when a skip tensor is supplied, because
        it is sized for the concatenated channel count.
        """
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            # Concatenate along the channel axis (dim=1).
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class CenterBlock(nn.Sequential):
    """Two stacked ``md.Conv3dReLU`` layers applied to the encoder head.

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count produced by both convolutions.
        use_batchnorm: Forwarded unchanged to ``md.Conv3dReLU``.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = md.Conv3dReLU(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        conv2 = md.Conv3dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        super().__init__(conv1, conv2)


class UnetDecoder_3D(nn.Module):
    """U-Net style decoder assembled from ``DecoderBlock`` stages.

    Args:
        encoder_channels: Per-stage channel counts of the encoder features, in
            encoder order; the sequence is reversed internally so the deepest
            feature comes first.
        decoder_channels: Output channel count of each decoder block.
        n_blocks: Expected number of decoder blocks; must equal
            ``len(decoder_channels)``.
        use_batchnorm: Forwarded to every convolution block.
        attention_type: Attention variant forwarded to each ``DecoderBlock``
            (``None`` is passed through as-is).
        center: When true, a ``CenterBlock`` processes the encoder head before
            decoding; otherwise the head passes through ``nn.Identity``.
        deep_supervision: When true and the module is in training mode,
            ``forward`` returns the output of every decoder block instead of
            only the last one.

    Raises:
        ValueError: If ``n_blocks`` differs from ``len(decoder_channels)``.

    Note:
        Earlier revisions rebuilt the block list a second time after the
        construction loop, reusing keyword arguments that the loop had already
        switched to the "last block" setting. That silently disabled
        ``attention1`` in every block. The blocks are now built exactly once,
        so only the final block skips ``attention1``. With a non-``None``
        ``attention_type`` this may change which submodules exist, so
        checkpoints saved from the old behaviour may need non-strict loading.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
        deep_supervision=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        # The trailing 0 gives the final block no skip channels.
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        self.deep_supervision = deep_supervision

        last_idx = len(in_channels) - 1
        blocks = []
        for block_idx, (in_ch, skip_ch, out_ch) in enumerate(
            zip(in_channels, skip_channels, out_channels)
        ):
            # For the last block, attention1 is not used. The pair is computed
            # per block so no shared state leaks between iterations.
            first_attention = None if block_idx == last_idx else attention_type
            blocks.append(
                DecoderBlock(
                    in_ch,
                    skip_ch,
                    out_ch,
                    use_batchnorm=use_batchnorm,
                    attention_type=[first_attention, attention_type],
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):
        """Decode encoder features.

        Args:
            *features: Encoder feature tensors in encoder order; they are
                reversed here so the first one used is the deepest.

        Returns:
            The final decoder tensor, or - when ``deep_supervision`` is enabled
            and the module is in training mode - a list holding the output of
            every decoder block in execution order.
        """
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)

        collect_outputs = self.deep_supervision and self.training
        outputs = []
        for i, decoder_block in enumerate(self.blocks):
            # Blocks beyond the available skip tensors run without a skip.
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
            if collect_outputs:
                outputs.append(x)

        if collect_outputs:
            return outputs
        return x