import torch
import torch.nn as nn
import torch.nn.functional as F

from attention import ChannelAttention, SpatialAttention, DualCrossModalAttention
from srm_conv import SRMConv2d_simple, SRMConv2d_Separate
from xception import TransferModel


class SRMPixelAttention(nn.Module):
    '''Predicts a spatial attention map from SRM noise residuals.'''

    def __init__(self, in_channels):
        super(SRMPixelAttention, self).__init__()
        # self.srm = SRMConv2d_simple()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.pa = SpatialAttention()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x_srm):
        # x_srm = self.srm(x)
        fea = self.conv(x_srm)
        att_map = self.pa(fea)
        return att_map


class FeatureFusionModule(nn.Module):
    '''Fuses the RGB and SRM streams: a 1x1 conv on the concatenated
    features, followed by a channel-attention reweighting.'''

    def __init__(self, in_chan=2048 * 2, out_chan=2048, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU()
        )
        self.ca = ChannelAttention(out_chan, ratio=16)
        self.init_weight()

    def forward(self, x, y):
        fuse_fea = self.convblk(torch.cat((x, y), dim=1))
        fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
        return fuse_fea

    def init_weight(self):
        # Iterate over modules() rather than children() so the Conv2d
        # inside the Sequential block is actually reached.
        for ly in self.modules():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)


class Two_Stream_Net(nn.Module):
    '''Two-stream Xception: an RGB stream and an SRM noise stream, linked by
    SRM-guided spatial attention and dual cross-modal attention.'''

    def __init__(self):
        super().__init__()
        self.xception_rgb = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)
        self.xception_srm = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)

        # Fixed SRM filters on the input image, plus per-stage SRM filtering
        # of intermediate RGB features injected into the noise stream.
        self.srm_conv0 = SRMConv2d_simple(inc=3)
        self.srm_conv1 = SRMConv2d_Separate(32, 32)
        self.srm_conv2 = SRMConv2d_Separate(64, 64)
        self.relu = nn.ReLU(inplace=True)

        self.att_map = None
        self.srm_sa = SRMPixelAttention(3)
        self.srm_sa_post = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Cross-modal attention at the 728-channel middle-flow stages.
        self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
        self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)

        self.fusion = FeatureFusionModule()

        self.att_dic = {}

    def features(self, x):
        srm = self.srm_conv0(x)

        # Stem: the noise stream receives SRM-filtered copies of the
        # RGB features after each of the first two conv blocks.
        x = self.xception_rgb.model.fea_part1_0(x)
        y = self.xception_srm.model.fea_part1_0(srm) + self.srm_conv1(x)
        y = self.relu(y)

        x = self.xception_rgb.model.fea_part1_1(x)
        y = self.xception_srm.model.fea_part1_1(y) + self.srm_conv2(x)
        y = self.relu(y)

        # SRM-guided spatial attention on the RGB stream.
        self.att_map = self.srm_sa(srm)
        x = x * self.att_map + x
        x = self.srm_sa_post(x)

        # Middle flow with dual cross-modal attention between the streams.
        x = self.xception_rgb.model.fea_part2(x)
        y = self.xception_srm.model.fea_part2(y)
        x, y = self.dual_cma0(x, y)

        x = self.xception_rgb.model.fea_part3(x)
        y = self.xception_srm.model.fea_part3(y)
        x, y = self.dual_cma1(x, y)

        x = self.xception_rgb.model.fea_part4(x)
        y = self.xception_srm.model.fea_part4(y)

        x = self.xception_rgb.model.fea_part5(x)
        y = self.xception_srm.model.fea_part5(y)

        # Fuse the two final feature maps into one.
        fea = self.fusion(x, y)
        return fea

    def classifier(self, fea):
        out, fea = self.xception_rgb.classifier(fea)
        return out, fea

    def forward(self, x):
        '''
        x: original RGB image
        Returns:
            out: (B, 2) logits for loss computation
            fea: (B, 1024) the flattened features before the last FC
            att_map: SRM-guided spatial attention map
        '''
        out, fea = self.classifier(self.features(x))
        return out, fea, self.att_map


if __name__ == '__main__':
    model = Two_Stream_Net()
    dummy = torch.rand((1, 3, 256, 256))
    out, fea, att_map = model(dummy)
    print(model)
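
    # Sanity check on the three outputs (a small addition to the original
    # smoke test; exact spatial sizes depend on the input resolution):
    print(out.shape)      # real/fake logits, expected (1, 2)
    print(fea.shape)      # flattened features before the final FC
    print(att_map.shape)  # SRM-guided attention map, (1, 1, H, W)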
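
    # Minimal training-step sketch (an assumption, not taken from the
    # original script): plain cross-entropy on the (B, 2) logits, which is
    # what the forward() docstring implies.
    labels = torch.tensor([1])  # hypothetical label: 1 = fake (assumed)
    loss = F.cross_entropy(out, labels)
    loss.backward()
    print('loss:', loss.item())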