import torch
import torch.nn as nn
import torch.nn.functional as F
class FPA(nn.Module):
    def __init__(self, channels=2048):
        """
        Feature Pyramid Attention
        :param channels: Number of input (and output) channels, e.g. 2048 for a ResNet C5 map.
        :type channels: int
        """
        super(FPA, self).__init__()
        channels_mid = channels // 4
        self.channels_cond = channels
        # Master branch
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)
        # Global pooling branch (no BN here: it cannot be applied to 1x1 feature maps)
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        # Pyramid branch: three stride-2 downsampling stages. The deepest
        # backbone feature map is assumed to have spatial size (16, 16).
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)
        # Same-scale refinement convolutions (stride 1) for each pyramid level
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)
        # 1x1 convolution restores the full channel count after the pyramid merge
        self.bn_upsample_1 = nn.BatchNorm2d(channels)
        self.conv1x1_up1 = nn.Conv2d(channels_mid, channels, kernel_size=(1, 1), stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        """
        :param x: Input feature maps. Shape: [b, 2048, h, w]
        :return: Attended feature maps. Shape: [b, 2048, h, w]
        """
        # Master branch
        x_master = self.conv_master(x)
        x_master = self.bn_master(x_master)
        # Global pooling branch: squeeze spatial dims to 1x1, then 1x1 conv
        x_gpb = F.adaptive_avg_pool2d(x, 1)
        x_gpb = self.conv_gpb(x_gpb)
        # Branch 1
        x1_1 = self.conv7x7_1(x)
        x1_1 = self.bn1_1(x1_1)
        x1_1 = self.relu(x1_1)
        x1_2 = self.conv7x7_2(x1_1)
        x1_2 = self.bn1_2(x1_2)
        # Branch 2
        x2_1 = self.conv5x5_1(x1_1)
        x2_1 = self.bn2_1(x2_1)
        x2_1 = self.relu(x2_1)
        x2_2 = self.conv5x5_2(x2_1)
        x2_2 = self.bn2_2(x2_2)
        # Branch 3
        x3_1 = self.conv3x3_1(x2_1)
        x3_1 = self.bn3_1(x3_1)
        x3_1 = self.relu(x3_1)
        x3_2 = self.conv3x3_2(x3_1)
        x3_2 = self.bn3_2(x3_2)
        # Merge branch 3 into branch 2, then branch 2 into branch 1
        x3_upsample = F.interpolate(x3_2, size=x2_2.shape[-2:],
                                    mode='bilinear', align_corners=False)
        x2_merge = self.relu(x2_2 + x3_upsample)
        x2_upsample = F.interpolate(x2_merge, size=x1_2.shape[-2:],
                                    mode='bilinear', align_corners=False)
        x1_merge = self.relu(x1_2 + x2_upsample)
        # Upsample the merged pyramid to the master resolution and restore channels
        x1_merge_upsample = F.interpolate(x1_merge, size=x_master.shape[-2:],
                                          mode='bilinear', align_corners=False)
        x1_merge_upsample_ch = self.relu(self.bn_upsample_1(self.conv1x1_up1(x1_merge_upsample)))
        x_master = x_master * x1_merge_upsample_ch
        # Add the global pooling branch (broadcast over spatial dims)
        out = self.relu(x_master + x_gpb)
        return out
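# A minimal shape check for FPA (a sketch; the 16x16 input follows the
# assumption noted in __init__, and the batch size is arbitrary):
#
#     fpa = FPA(channels=2048)
#     y = fpa(torch.randn(2, 2048, 16, 16))
#     assert y.shape == (2, 2048, 16, 16)  # channels and spatial size preserved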
class GAU(nn.Module):
    def __init__(self, channels_high, channels_low, upsample=True):
        """
        Global Attention Upsample
        :param channels_high: Channels of the high-level (deeper) feature maps.
        :param channels_low: Channels of the low-level (shallower) feature maps.
        :param upsample: If True, upsample the high-level maps 2x with a transposed
            convolution; otherwise only reduce their channels with a 1x1 convolution.
        """
        super(GAU, self).__init__()
        self.upsample = upsample
        self.conv3x3 = nn.Conv2d(channels_low, channels_low, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_low)
        # No BN after this 1x1 conv: it operates on 1x1 global-pooled maps
        self.conv1x1 = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
        if upsample:
            self.conv_upsample = nn.ConvTranspose2d(channels_high, channels_low, kernel_size=4, stride=2, padding=1, bias=False)
            self.bn_upsample = nn.BatchNorm2d(channels_low)
        else:
            self.conv_reduction = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
            self.bn_reduction = nn.BatchNorm2d(channels_low)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, fms_high, fms_low, fm_mask=None):
        """
        Use the high-level features, which carry abundant category information, to
        weight the low-level features, which carry precise pixel-localization
        information. Mask feature maps with category-specific information could
        additionally be used to localize the mask position (currently disabled).
        :param fms_high: High-level feature maps. Tensor of shape [b, channels_high, h, w].
        :param fms_low: Low-level feature maps. Tensor of shape [b, channels_low, 2h, 2w] when upsample=True.
        :param fm_mask: Optional mask feature maps; unused in this implementation.
        :return: Attention-weighted, upsampled feature maps.
        """
        b, c, h, w = fms_high.shape
        # Global context of the high-level maps, squeezed to [b, c, 1, 1]
        fms_high_gp = F.adaptive_avg_pool2d(fms_high, 1)
        fms_high_gp = self.conv1x1(fms_high_gp)
        # No BN here: it cannot be applied when the spatial size is 1x1
        fms_high_gp = self.relu(fms_high_gp)
        # fms_low_mask = torch.cat([fms_low, fm_mask], dim=1)
        fms_low_mask = self.conv3x3(fms_low)
        fms_low_mask = self.bn_low(fms_low_mask)
        fms_att = fms_low_mask * fms_high_gp
        if self.upsample:
            out = self.relu(
                self.bn_upsample(self.conv_upsample(fms_high)) + fms_att)
        else:
            out = self.relu(
                self.bn_reduction(self.conv_reduction(fms_high)) + fms_att)
        return out
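# A minimal shape check for GAU (a sketch; the 2x spatial ratio between the
# low- and high-level maps mirrors adjacent ResNet stages and is an assumption):
#
#     gau = GAU(channels_high=2048, channels_low=1024)
#     high = torch.randn(2, 2048, 8, 8)
#     low = torch.randn(2, 1024, 16, 16)
#     out = gau(high, low)
#     assert out.shape == (2, 1024, 16, 16)  # low-level channels at low-level resolution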
class PAN(nn.Module):
    def __init__(self):
        """
        Pyramid Attention Network decoder head. The channel counts below match
        the four stages of a ResNet-50/101 backbone, ordered deep to shallow.
        """
        super(PAN, self).__init__()
        channels_blocks = [2048, 1024, 512, 256]
        self.fpa = FPA(channels=channels_blocks[0])
        self.gau_block1 = GAU(channels_blocks[0], channels_blocks[1])
        self.gau_block2 = GAU(channels_blocks[1], channels_blocks[2])
        self.gau_block3 = GAU(channels_blocks[2], channels_blocks[3])
        self.gau = nn.ModuleList([self.gau_block1, self.gau_block2, self.gau_block3])
    def forward(self, fms):
        """
        :param fms: Backbone feature maps ordered shallow to deep (e.g. [C2, C3, C4, C5]
            for a ResNet); the list is reversed internally so the deepest map feeds the FPA.
        :return: Tuple of refined feature maps, ordered shallow to deep. The shallowest
            has shape [b, 256, h, w].
        """
        feats = []
        for i, fm_low in enumerate(fms[::-1]):
            if i == 0:
                # Deepest map goes through Feature Pyramid Attention
                fm_high = self.fpa(fm_low)
            else:
                # Each GAU fuses the running high-level output with the next shallower map
                fm_high = self.gau[i - 1](fm_high, fm_low)
            feats.append(fm_high)
        feats.reverse()
        return tuple(feats)
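# A minimal smoke test (a sketch: the spatial sizes below mimic a ResNet-50
# backbone on a 256x256 input and are assumptions, not values from this module).
if __name__ == "__main__":
    c2 = torch.randn(1, 256, 64, 64)
    c3 = torch.randn(1, 512, 32, 32)
    c4 = torch.randn(1, 1024, 16, 16)
    c5 = torch.randn(1, 2048, 8, 8)
    pan = PAN().eval()  # eval mode: BatchNorm fails on batch size 1 in training
    with torch.no_grad():
        feats = pan([c2, c3, c4, c5])
    for f in feats:
        print(tuple(f.shape))  # (1, 256, 64, 64) ... (1, 2048, 8, 8), shallow to deep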