import torch.nn as nn
import torch.nn.functional as F


def _resize_to(x, ref):
    """Bilinearly resize ``x`` to the spatial size (last two dims) of ``ref``."""
    return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)


class FPA(nn.Module):
    def __init__(self, channels=2048):
        """
        Feature Pyramid Attention.

        :param channels: Number of input channels; the output has the same number.
        :type channels: int
        """
        super(FPA, self).__init__()
        channels_mid = channels // 4
        self.channels_cond = channels

        # Master branch: 1x1 conv + BN, keeps the input's spatial size.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch: 1x1 conv on the globally averaged input.
        # BN is intentionally not applied here: its input is spatially 1x1, and BatchNorm
        # in training mode rejects a single value per channel (batch size 1).
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)

        # Pyramid branch, downsampling path (stride 2 at each level).
        # NOTE(review): the original author's comment read "C333 because of the shape of last
        # feature maps is (16, 16)", yet the kernels below are 7x7 / 5x5 / 3x3 -- confirm intent.
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7),
                                   stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5),
                                   stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3),
                                   stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Same-resolution refinement convs at each pyramid level (stride 1).
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7),
                                   stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5),
                                   stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3),
                                   stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Projects the merged pyramid back from channels_mid to channels.
        self.bn_upsample_1 = nn.BatchNorm2d(channels)
        self.conv1x1_up1 = nn.Conv2d(channels_mid, channels, kernel_size=(1, 1),
                                     stride=1, padding=0, bias=False)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Input feature maps. Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        # Master branch.
        x_master = self.bn_master(self.conv_master(x))

        # Global pooling branch: [b, c, h, w] -> [b, c, 1, 1].
        x_gpb = self.conv_gpb(F.adaptive_avg_pool2d(x, 1))

        # Pyramid, downsampling path. Each *_1 tensor feeds the next level, each *_2 is the
        # refined output of its level.
        x1_1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        x1_2 = self.bn1_2(self.conv7x7_2(x1_1))

        x2_1 = self.relu(self.bn2_1(self.conv5x5_1(x1_1)))
        x2_2 = self.bn2_2(self.conv5x5_2(x2_1))

        x3_1 = self.relu(self.bn3_1(self.conv3x3_1(x2_1)))
        x3_2 = self.bn3_2(self.conv3x3_2(x3_1))

        # Upsampling path: add each coarser level into the next finer one.
        x2_merge = self.relu(x2_2 + _resize_to(x3_2, x2_2))
        x1_merge = self.relu(x1_2 + _resize_to(x2_merge, x1_2))

        # Attention map at the input resolution, with the master branch's channel count.
        attention = self.relu(self.bn_upsample_1(self.conv1x1_up1(_resize_to(x1_merge, x_master))))

        x_master = x_master * attention
        # x_gpb is [b, c, 1, 1] and broadcasts over the spatial dims.
        out = self.relu(x_master + x_gpb)
        return out


class GAU(nn.Module):
    def __init__(self, channels_high, channels_low, upsample=True):
        """
        Global Attention Upsample.

        :param channels_high: Channels of the high-level input.
        :param channels_low: Channels of the low-level input (and of the output).
        :param upsample: If True, the high-level features are upsampled 2x with a transposed
            conv; otherwise they are only channel-reduced with a 1x1 conv.
        """
        super(GAU, self).__init__()
        self.upsample = upsample

        self.conv3x3 = nn.Conv2d(channels_low, channels_low, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_low)

        # Maps the globally pooled high-level vector to a per-channel weight for the low level.
        self.conv1x1 = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)

        if upsample:
            self.conv_upsample = nn.ConvTranspose2d(channels_high, channels_low, kernel_size=4,
                                                    stride=2, padding=1, bias=False)
            self.bn_upsample = nn.BatchNorm2d(channels_low)
        else:
            self.conv_reduction = nn.Conv2d(channels_high, channels_low, kernel_size=1,
                                            padding=0, bias=False)
            self.bn_reduction = nn.BatchNorm2d(channels_low)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_high, fms_low, fm_mask=None):
        """
        Use the high-level features, which carry category information, to weight the
        low-level features, which carry pixel-localization information.

        :param fms_high: High-level features. Tensor [b, channels_high, h, w].
        :param fms_low: Low-level features. Tensor [b, channels_low, H, W].
        :param fm_mask: Unused; kept for interface compatibility.
        :return: Tensor [b, channels_low, H, W].
        """
        # Per-channel attention from the high-level features: [b, channels_low, 1, 1].
        # No BN here: the original author notes BN cannot be used when HxW = 1x1.
        fms_high_gp = self.relu(self.conv1x1(F.adaptive_avg_pool2d(fms_high, 1)))

        fms_low_mask = self.bn_low(self.conv3x3(fms_low))
        fms_att = fms_low_mask * fms_high_gp

        if self.upsample:
            fms_high_proj = self.bn_upsample(self.conv_upsample(fms_high))
        else:
            fms_high_proj = self.bn_reduction(self.conv_reduction(fms_high))

        # The transposed conv yields exactly 2x the high-level size (and the 1x1 reduction keeps
        # it); align to the low-level size when they differ, e.g. for odd input resolutions.
        if fms_high_proj.shape[-2:] != fms_att.shape[-2:]:
            fms_high_proj = _resize_to(fms_high_proj, fms_att)

        return self.relu(fms_high_proj + fms_att)


class PAN(nn.Module):
    def __init__(self, channels_blocks=(2048, 1024, 512, 256)):
        """
        Pyramid Attention Network head: an FPA on the deepest features followed by a chain of
        GAUs that progressively fuse shallower features.

        :param channels_blocks: Channel counts of the input feature maps, deepest first.
            The first entry feeds the FPA; each consecutive pair defines one GAU.
        """
        super(PAN, self).__init__()
        channels_blocks = list(channels_blocks)
        if not channels_blocks:
            raise ValueError("channels_blocks must contain at least one channel count")

        self.fpa = FPA(channels=channels_blocks[0])

        # Registered as gau_block1, gau_block2, ... so state_dict keys match the original
        # layout. `self.gau` is a plain list on purpose: an nn.ModuleList would register the
        # same modules a second time under extra `gau.*` keys.
        self.gau = []
        pairs = zip(channels_blocks[:-1], channels_blocks[1:])
        for idx, (c_high, c_low) in enumerate(pairs, start=1):
            block = GAU(c_high, c_low)
            self.add_module('gau_block%d' % idx, block)
            self.gau.append(block)

    def forward(self, fms):
        """
        :param fms: Sequence of feature maps [b, c, h, w], ordered shallowest to deepest
            (the deepest map must have channels_blocks[0] channels).
        :return: Tuple of fused feature maps in the same order as ``fms``; entry k has the
            channel count and spatial size of fms[k].
        """
        feats = []
        fm_high = None
        # Walk from the deepest map to the shallowest.
        for i, fm_low in enumerate(fms[::-1]):
            if i == 0:
                fm_high = self.fpa(fm_low)
            else:
                fm_high = self.gau[i - 1](fm_high, fm_low)
            feats.append(fm_high)
        feats.reverse()
        return tuple(feats)