# Cyril666's picture
# First model version
# 4ea50ff
import torch.nn as nn
import torch.nn.functional as F
class FPA(nn.Module):
    """Feature Pyramid Attention (FPA) module.

    The output is ``relu(master * attention + gpb)`` where

    * ``master`` is a 1x1 conv + BN of the input (same resolution),
    * ``gpb`` is a 1x1 conv of the globally average-pooled input
      (shape ``[b, channels, 1, 1]``), broadcast over the spatial dimensions,
    * ``attention`` comes from a three-level pyramid (7x7, 5x5 and 3x3 convs,
      each level produced by a stride-2 conv from the previous one), merged
      top-down with bilinear upsampling, resized to the input resolution and
      projected back to ``channels`` with a 1x1 conv + BN + ReLU.
    """

    def __init__(self, channels=2048):
        """
        :param channels: Number of input channels, which is also the number of
            output channels. The pyramid branch uses ``channels // 4`` channels.
        :type channels: int
        """
        super(FPA, self).__init__()
        channels_mid = channels // 4
        self.channels_cond = channels

        # Master branch
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch. No BatchNorm here: it would act on a 1x1 pooled map.
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)

        # Pyramid branch. The *_1 convs (stride 2) build each coarser level from
        # the previous one; the *_2 convs (stride 1) refine a level at its own
        # resolution. Attribute names/order are kept for checkpoint compatibility.
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Projection of the merged pyramid back to `channels` (the attention map).
        self.bn_upsample_1 = nn.BatchNorm2d(channels)
        self.conv1x1_up1 = nn.Conv2d(channels_mid, channels, kernel_size=(1, 1), stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _resize(src, ref):
        """Bilinearly resize ``src`` to the spatial size (last two dims) of ``ref``."""
        return F.interpolate(src, size=ref.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        """
        :param x: Input feature maps. Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        # Master branch
        x_master = self.bn_master(self.conv_master(x))

        # Global pooling branch: [b, c, h, w] -> [b, c, 1, 1]; broadcast in the final sum.
        x_gpb = self.conv_gpb(F.adaptive_avg_pool2d(x, 1))

        # Pyramid level 1 (7x7, ~1/2 of the input resolution)
        x1_1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        x1_2 = self.bn1_2(self.conv7x7_2(x1_1))
        # Pyramid level 2 (5x5, ~1/4)
        x2_1 = self.relu(self.bn2_1(self.conv5x5_1(x1_1)))
        x2_2 = self.bn2_2(self.conv5x5_2(x2_1))
        # Pyramid level 3 (3x3, ~1/8)
        x3_1 = self.relu(self.bn3_1(self.conv3x3_1(x2_1)))
        x3_2 = self.bn3_2(self.conv3x3_2(x3_1))

        # Top-down merge: upsample each coarser level onto the next finer one and add.
        x2_merge = self.relu(x2_2 + self._resize(x3_2, x2_2))
        x1_merge = self.relu(x1_2 + self._resize(x2_merge, x1_2))

        # Attention map at the master-branch resolution with `channels` channels.
        x1_merge_upsample = self._resize(x1_merge, x_master)
        attention = self.relu(self.bn_upsample_1(self.conv1x1_up1(x1_merge_upsample)))

        out = self.relu(x_master * attention + x_gpb)
        return out
class GAU(nn.Module):
    """Global Attention Upsample (GAU) module.

    Uses the globally pooled high-level features as a channel-wise attention
    vector for the (3x3 conv + BN processed) low-level features, then adds the
    high-level features mapped to the low-level channel count and resolution.
    """

    def __init__(self, channels_high, channels_low, upsample=True):
        """
        :param channels_high: Channels of the high-level input ``fms_high``.
        :param channels_low: Channels of the low-level input ``fms_low``; also
            the number of output channels.
        :param upsample: If True, ``fms_high`` is upsampled x2 with a 4x4
            stride-2 transposed convolution before being added. If False, it is
            only channel-reduced with a 1x1 convolution, so it must already
            have the same spatial size as ``fms_low``.
        """
        super(GAU, self).__init__()
        self.upsample = upsample
        self.conv3x3 = nn.Conv2d(channels_low, channels_low, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_low)
        # No BatchNorm after conv1x1: it acts on a 1x1 pooled map (see forward).
        self.conv1x1 = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
        if upsample:
            self.conv_upsample = nn.ConvTranspose2d(channels_high, channels_low, kernel_size=4, stride=2, padding=1, bias=False)
            self.bn_upsample = nn.BatchNorm2d(channels_low)
        else:
            self.conv_reduction = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
            self.bn_reduction = nn.BatchNorm2d(channels_low)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_high, fms_low, fm_mask=None):
        """
        Use the high-level features, which carry rich category information, to
        weight the low-level features, which carry precise pixel-localization
        information.

        :param fms_high: High-level feature maps, shape [b, channels_high, h, w].
        :param fms_low: Low-level feature maps, shape [b, channels_low, H, W].
            With ``upsample=True``, H and W may each be 2*h (or 2*w) or one more
            than that. Other sizes are rejected by the transposed convolution.
            With ``upsample=False``, (H, W) must equal (h, w).
        :param fm_mask: Currently unused; accepted for interface compatibility.
        :return: Tensor of shape [b, channels_low, H, W].
        """
        # Channel attention vector: [b, channels_high, 1, 1] -> [b, channels_low, 1, 1].
        # The original author notes BN cannot be used here because the map is 1x1.
        fms_high_gp = self.relu(self.conv1x1(F.adaptive_avg_pool2d(fms_high, 1)))

        fms_low_mask = self.bn_low(self.conv3x3(fms_low))
        fms_att = fms_low_mask * fms_high_gp

        if self.upsample:
            # A stride-2 transposed conv can produce 2h or 2h+1 outputs.
            # output_size picks the one matching fms_low, so odd sizes also work.
            fms_high_up = self.conv_upsample(fms_high, output_size=list(fms_low.shape[-2:]))
            out = self.relu(self.bn_upsample(fms_high_up) + fms_att)
        else:
            out = self.relu(self.bn_reduction(self.conv_reduction(fms_high)) + fms_att)
        return out
class PAN(nn.Module):
    """Pyramid Attention Network decoder: an FPA module followed by a chain of GAU blocks."""

    def __init__(self, channels_blocks=(2048, 1024, 512, 256)):
        """
        :param channels_blocks: Channel counts of the input feature maps, in the
            order they are processed. ``channels_blocks[0]`` is the width of the
            map fed to FPA. One GAU block is built for every consecutive pair
            ``(channels_blocks[k-1], channels_blocks[k])``.
            Defaults to (2048, 1024, 512, 256).
        """
        super(PAN, self).__init__()
        channels_blocks = tuple(channels_blocks)
        self.fpa = FPA(channels=channels_blocks[0])
        gau_blocks = []
        for idx in range(1, len(channels_blocks)):
            block = GAU(channels_blocks[idx - 1], channels_blocks[idx])
            # Register under the historical names gau_block1, gau_block2, ...
            # so state_dict keys stay compatible with existing checkpoints.
            setattr(self, 'gau_block{}'.format(idx), block)
            gau_blocks.append(block)
        # Plain list on purpose: the blocks are already registered as attributes
        # above, and wrapping them in nn.ModuleList would add extra `gau.N.*`
        # entries to state_dict.
        self.gau = gau_blocks

    def forward(self, fms):
        """
        :param fms: Sequence (list/tuple) of feature maps, each [b, c_i, h_i, w_i].
            The LAST element is fed to FPA and must have ``channels_blocks[0]``
            channels. Element ``fms[-k]`` (k >= 2) is fused by the (k-1)-th GAU
            block and must have ``channels_blocks[k-1]`` channels.
        :return: Tuple with one output per input map, in the same order as
            ``fms``. The last entry is the FPA output; each earlier entry has
            the channel count and spatial size of the corresponding input map.
        """
        feats = []
        for i, fm_low in enumerate(fms[::-1]):
            if i == 0:
                # Start of the chain: the last input map goes through FPA.
                fm_high = self.fpa(fm_low)
            else:
                # Fuse the next map with the output of the previous stage.
                fm_high = self.gau[i - 1](fm_high, fm_low)
            feats.append(fm_high)
        feats.reverse()
        return tuple(feats)