import torch
import torch.nn as nn
import torch.nn.functional as F

# Interpolation modes for which ``F.interpolate`` raises ValueError if
# ``align_corners`` is anything other than ``None``.
_NON_ALIGNABLE_MODES = ("nearest", "area", "nearest-exact")


def _align_corners_for(mode: str):
    """Return the ``align_corners`` value to pass alongside ``mode``.

    ``True`` for ``"bilinear"`` (the convention used throughout this module),
    ``None`` for modes where ``F.interpolate`` forbids the flag, and ``False``
    for the remaining interpolating modes (e.g. ``"bicubic"``).
    """
    if mode == "bilinear":
        return True
    if mode in _NON_ALIGNABLE_MODES:
        return None
    return False


class ConvBnRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU -> optional 2x bilinear upsample.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding, dilation, groups, bias: forwarded
            unchanged to ``nn.Conv2d``.
        add_relu: apply an in-place ReLU after batch norm.
        interpolate: upsample the result by a factor of 2 (bilinear,
            ``align_corners=True``) at the end of ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        add_relu: bool = True,
        interpolate: bool = False,
    ):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            groups=groups,
        )
        self.add_relu = add_relu
        self.interpolate = interpolate
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        if self.add_relu:
            x = self.activation(x)
        if self.interpolate:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return x


class FPABlock(nn.Module):
    """Pyramid attention block applied to the deepest encoder feature.

    Presumably the "Feature Pyramid Attention" module of PAN (name-based;
    verify against the paper). It combines three paths:

    * a global-average-pooled branch (``branch1``), broadcast back to H x W;
    * a 1x1 projection of the input (``mid``);
    * a single-channel pyramid (``down1``..``down3``, ``conv1``/``conv2``)
      that is merged top-down into an H x W attention map, multiplied into
      ``mid``.

    The output is ``attention * mid + branch1``, with ``out_channels``
    channels and the same spatial size as the input.

    Args:
        in_channels: channels of the input feature.
        out_channels: channels of the output feature.
        upscale_mode: ``F.interpolate`` mode used for every resize.
    """

    def __init__(
        self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
    ):
        super(FPABlock, self).__init__()
        self.upscale_mode = upscale_mode
        # Previously ``False`` for every non-bilinear mode, which made
        # F.interpolate raise for "nearest"/"area"; use None there.
        self.align_corners = _align_corners_for(upscale_mode)

        # global pooling branch
        self.branch1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )

        # middle branch
        self.mid = nn.Sequential(
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )
        # Pyramid path: each stage halves H and W and works on one channel.
        self.down1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=1,
                kernel_size=7,
                stride=1,
                padding=3,
            ),
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2
            ),
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
            ),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
            ),
        )

        self.conv2 = ConvBnRelu(
            in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2
        )
        self.conv1 = ConvBnRelu(
            in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Assumes ``x`` is (N, in_channels, H, W). Three 2x2 max-pools run in
        sequence, so H and W must each be at least 8.
        """
        h, w = x.size(2), x.size(3)
        upscale_parameters = dict(
            mode=self.upscale_mode, align_corners=self.align_corners
        )

        b1 = self.branch1(x)
        b1 = F.interpolate(b1, size=(h, w), **upscale_parameters)

        mid = self.mid(x)

        x1 = self.down1(x)  # H // 2
        x2 = self.down2(x1)  # H // 4
        x3 = self.down3(x2)  # H // 8
        # Top-down merge. floor(floor(h/2)/2) == h // 4, so the resize
        # targets below match x2 and x1 exactly, even for odd sizes.
        x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)

        x2 = self.conv2(x2)
        x = x2 + x3
        x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters)

        x1 = self.conv1(x1)
        x = x + x1
        x = F.interpolate(x, size=(h, w), **upscale_parameters)

        # Single-channel attention map broadcast across mid's channels.
        x = torch.mul(x, mid)
        x = x + b1
        return x


class GAUBlock(nn.Module):
    """Fuse a low-level feature with a higher-level (coarser) feature.

    The high-level feature is globally pooled into a per-channel sigmoid gate.
    The gate scales a 3x3 projection of the low-level feature, and the result
    is added to the high-level feature upsampled to the low-level resolution.

    Args:
        in_channels: channels of the low-level feature ``x``.
        out_channels: channels of the high-level feature ``y``, which is also
            the channel count of the output.
        upscale_mode: ``F.interpolate`` mode used to upsample ``y``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
    ):
        super(GAUBlock, self).__init__()
        self.upscale_mode = upscale_mode
        # Same helper as FPABlock. For interpolating modes other than
        # bilinear this yields False, which F.interpolate treats the same
        # as the previous None.
        self.align_corners = _align_corners_for(upscale_mode)

        # Global gate computed from y: (N, out_channels, 1, 1) in (0, 1).
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                add_relu=False,
            ),
            nn.Sigmoid(),
        )
        self.conv2 = ConvBnRelu(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: low level feature, (N, in_channels, H, W)
            y: high level feature, (N, out_channels, h, w)

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        h, w = x.size(2), x.size(3)
        y_up = F.interpolate(
            y, size=(h, w), mode=self.upscale_mode, align_corners=self.align_corners
        )
        x = self.conv2(x)
        y = self.conv1(y)
        z = torch.mul(x, y)  # gate broadcast over H x W

        return y_up + z


class PANDecoder(nn.Module):
    """Decoder that runs FPABlock on the deepest feature, then three GAU stages.

    Args:
        encoder_channels: channel counts of the encoder features, ordered
            from shallow to deep; the last four entries are used.
        decoder_channels: channel count of every decoder stage and the output.
        upscale_mode: interpolation mode for the GAU blocks.
    """

    def __init__(
        self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear"
    ):
        super().__init__()

        # NOTE(review): upscale_mode is not forwarded here, so the FPA block
        # always uses bilinear. Left unchanged to preserve existing outputs.
        self.fpa = FPABlock(
            in_channels=encoder_channels[-1], out_channels=decoder_channels
        )
        self.gau3 = GAUBlock(
            in_channels=encoder_channels[-2],
            out_channels=decoder_channels,
            upscale_mode=upscale_mode,
        )
        self.gau2 = GAUBlock(
            in_channels=encoder_channels[-3],
            out_channels=decoder_channels,
            upscale_mode=upscale_mode,
        )
        self.gau1 = GAUBlock(
            in_channels=encoder_channels[-4],
            out_channels=decoder_channels,
            upscale_mode=upscale_mode,
        )

    def forward(self, *features: torch.Tensor) -> torch.Tensor:
        """Decode encoder features (shallow to deep) into a 1/4-scale map."""
        bottleneck = features[-1]
        x5 = self.fpa(bottleneck)  # 1/32
        x4 = self.gau3(features[-2], x5)  # 1/16
        x3 = self.gau2(features[-3], x4)  # 1/8
        x2 = self.gau1(features[-4], x3)  # 1/4

        return x2