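# NOTE(review): the text below appears to be file-viewer residue (uploader
# avatar caption, commit message, short commit hash); kept as comments so the
# module parses:
#   ghlee94's picture / Init / 2a13495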
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBnRelu(nn.Module):
    """2D convolution followed by batch norm, optional ReLU and optional 2x upsampling.

    Args:
        in_channels: Number of channels of the input passed to the convolution.
        out_channels: Number of channels produced by the convolution (and
            normalized by the batch-norm layer).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        add_relu: When true, apply an in-place ReLU after batch norm.
        interpolate: When true, upsample the result by a factor of 2 using
            bilinear interpolation with ``align_corners=True``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        add_relu: bool = True,
        interpolate: bool = False,
    ):
        super().__init__()

        # Plain flags consulted in forward().
        self.add_relu = add_relu
        self.interpolate = interpolate

        # Sub-modules are created in the order conv -> bn -> activation so the
        # registration order (and therefore state_dict layout) is stable.
        conv_options = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        # The ReLU is always instantiated; it is only applied when add_relu is set.
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv and batch norm, then the optional ReLU and optional 2x upsampling."""
        out = self.bn(self.conv(x))
        if self.add_relu:
            out = self.activation(out)
        if not self.interpolate:
            return out
        return F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)
class FPABlock(nn.Module):
    """FPA block: a small single-channel pyramid gates a 1x1-projected feature map.

    The block combines three paths computed from the same input:

    * ``branch1``: global average pooling followed by a 1x1 ConvBnRelu, then
      resized back to the input's spatial size.
    * ``mid``: a 1x1 ConvBnRelu projecting the input to ``out_channels``.
    * ``down1``/``down2``/``down3`` with ``conv1``/``conv2``: three successive
      2x2 max-pool stages with 7x7, 5x5 and 3x3 convolutions that produce a
      single-channel map, merged coarse-to-fine and resized to the input size.

    The single-channel pyramid output is multiplied with ``mid`` (broadcast
    over channels) and ``branch1`` is added to the product.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        upscale_mode: ``mode`` passed to ``F.interpolate`` for every resize in
            this block. ``"bilinear"`` uses ``align_corners=True``; modes that
            do not accept ``align_corners`` ("nearest", "nearest-exact",
            "area") use ``None``; any other mode uses ``False``.
    """

    # F.interpolate raises ValueError if align_corners is given (even as
    # False) together with one of these modes, so they must receive None.
    _NON_INTERPOLATING_MODES = ("nearest", "nearest-exact", "area")

    def __init__(self, in_channels, out_channels, upscale_mode="bilinear"):
        super(FPABlock, self).__init__()

        self.upscale_mode = upscale_mode
        self.align_corners = self._resolve_align_corners(upscale_mode)

        # global pooling branch
        self.branch1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )

        # middle branch
        self.mid = nn.Sequential(
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )

        # Pyramid: each stage halves the resolution and works on one channel.
        self.down1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=1,
                kernel_size=7,
                stride=1,
                padding=3,
            ),
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2
            ),
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
            ),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
            ),
        )

        # Refinement convolutions applied to the 1/4 and 1/2 pyramid levels
        # before they are merged with the upsampled coarser level.
        self.conv2 = ConvBnRelu(
            in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2
        )
        self.conv1 = ConvBnRelu(
            in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3
        )

    @staticmethod
    def _resolve_align_corners(upscale_mode):
        """Return the ``align_corners`` value that is valid for ``upscale_mode``."""
        if upscale_mode == "bilinear":
            return True
        if upscale_mode in FPABlock._NON_INTERPOLATING_MODES:
            return None
        return False

    def forward(self, x):
        """Run the block on ``x``.

        ``x`` is indexed at dims 2 and 3 for height and width, and is max-pooled
        by 2 three times, so its spatial size must be large enough to survive
        three halvings.

        Returns:
            ``pyramid * mid + branch1`` at the spatial size of ``x``.
        """
        h, w = x.size(2), x.size(3)
        upscale_parameters = dict(
            mode=self.upscale_mode, align_corners=self.align_corners
        )

        # Global context: pooled to 1x1, projected, then stretched back to (h, w).
        b1 = self.branch1(x)
        b1 = F.interpolate(b1, size=(h, w), **upscale_parameters)

        mid = self.mid(x)

        # Descend the pyramid: 1/2, 1/4 and 1/8 of the input resolution.
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        # Ascend, merging each level with the refined level above it.
        x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)
        x2 = self.conv2(x2)
        x = x2 + x3
        x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters)

        x1 = self.conv1(x1)
        x = x + x1
        x = F.interpolate(x, size=(h, w), **upscale_parameters)

        # The single-channel pyramid map broadcasts over mid's channels.
        x = torch.mul(x, mid)
        x = x + b1
        return x
class GAUBlock(nn.Module):
    """GAU block: gate a low-level feature map with a channel vector from a high-level one.

    ``conv1`` turns the high-level input into per-channel weights (global
    average pool -> 1x1 ConvBnRelu without ReLU -> sigmoid). ``conv2`` projects
    the low-level input to ``out_channels`` with a 3x3 ConvBnRelu.

    Args:
        in_channels: Channels of the low-level input ``x``.
        out_channels: Channels of the output; ``conv1`` also takes this many
            input channels, so the high-level input ``y`` must have
            ``out_channels`` channels.
        upscale_mode: ``mode`` passed to ``F.interpolate`` when resizing ``y``.
            ``align_corners`` is ``True`` for ``"bilinear"`` and ``None``
            otherwise.
    """

    def __init__(
        self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
    ):
        super().__init__()

        self.upscale_mode = upscale_mode
        if upscale_mode == "bilinear":
            self.align_corners = True
        else:
            self.align_corners = None

        # Channel gate computed from the high-level feature.
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(out_channels, out_channels, kernel_size=1, add_relu=False),
            nn.Sigmoid(),
        ]
        self.conv1 = nn.Sequential(*gate_layers)

        # Projection of the low-level feature to the decoder width.
        self.conv2 = ConvBnRelu(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, y):
        """
        Args:
            x: low level feature
            y: high level feature

        Returns:
            ``y`` resized to the spatial size of ``x`` plus the projected ``x``
            multiplied by the channel gate derived from ``y``.
        """
        target_size = (x.size(2), x.size(3))

        gate = self.conv1(y)
        projected = self.conv2(x)
        gated = projected * gate

        upsampled = F.interpolate(
            y,
            size=target_size,
            mode=self.upscale_mode,
            align_corners=self.align_corners,
        )
        return upsampled + gated
class PANDecoder(nn.Module):
    """Decoder that applies an FPABlock to the last feature and three GAUBlocks above it.

    Args:
        encoder_channels: Indexable sequence of channel counts; the last four
            entries are used (``[-1]`` for the FPA block, ``[-2]``, ``[-3]``
            and ``[-4]`` for ``gau3``, ``gau2`` and ``gau1``).
        decoder_channels: Output channels of every block in the decoder.
        upscale_mode: Passed to the three GAU blocks only; the FPA block is
            built with its own default upscale mode.
    """

    def __init__(
        self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear"
    ):
        super().__init__()

        def build_gau(skip_channels):
            # All GAU blocks share the decoder width and upscale mode.
            return GAUBlock(
                in_channels=skip_channels,
                out_channels=decoder_channels,
                upscale_mode=upscale_mode,
            )

        self.fpa = FPABlock(
            in_channels=encoder_channels[-1], out_channels=decoder_channels
        )
        self.gau3 = build_gau(encoder_channels[-2])
        self.gau2 = build_gau(encoder_channels[-3])
        self.gau1 = build_gau(encoder_channels[-4])

    def forward(self, *features):
        """Decode the last four entries of ``features`` and return the final GAU output.

        ``features[-1]`` goes through the FPA block; each GAU block then takes
        the next shallower feature as its low-level input and the previous
        result as its high-level input.
        """
        # Scale labels follow the original inline comments: 1/32 after the FPA
        # block, then 1/16, 1/8 and 1/4 after gau3, gau2 and gau1.
        out = self.fpa(features[-1])
        stages = ((self.gau3, -2), (self.gau2, -3), (self.gau1, -4))
        for gau, skip_index in stages:
            out = gau(features[skip_index], out)
        return out