""" attention modules in ['SimAM', 'CBAM', 'SE', 'GAM'] were applied in the ablation study ver: Dec 24th 15:00 ref: https://github.com/xmu-xiaoma666/External-Attention-pytorch """ import torch import torch.nn as nn import math import torch.nn.functional as F from torch.nn import init # help func class BasicConv(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): super(BasicConv, self).__init__() self.out_channels = out_planes self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None self.relu = nn.ReLU() if relu else None def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ChannelGate(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): super(ChannelGate, self).__init__() self.gate_channels = gate_channels self.mlp = nn.Sequential( Flatten(), nn.Linear(gate_channels, int(gate_channels // reduction_ratio)), nn.ReLU(), nn.Linear(int(gate_channels // reduction_ratio), gate_channels) ) self.pool_types = pool_types def forward(self, x): channel_att_sum = None for pool_type in self.pool_types: if pool_type == 'avg': avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp(avg_pool) elif pool_type == 'max': max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp(max_pool) elif pool_type == 'lp': lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp(lp_pool) elif pool_type == 'lse': # LSE pool only lse_pool = logsumexp_2d(x) channel_att_raw = self.mlp(lse_pool) if channel_att_sum is None: channel_att_sum = channel_att_raw else: channel_att_sum = channel_att_sum + channel_att_raw scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) return x * scale def logsumexp_2d(tensor): tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() return outputs class ChannelPool(nn.Module): def forward(self, x): return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) class SpatialGate(nn.Module): def __init__(self): super(SpatialGate, self).__init__() kernel_size = 7 self.compress = ChannelPool() self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=int((kernel_size - 1) // 2), relu=False) def forward(self, x): x_compress = self.compress(x) x_out = self.spatial(x_compress) scale = F.sigmoid(x_out) # broadcasting return x * scale # attention modules: class cbam_module(nn.Module): """ module:CBAM input、output= b, c, h, w paper: https://arxiv.org/abs/1807.06521 code: https://github.com/ZjjConan/SimAM/blob/master/networks/attentions """ def __init__(self, gate_channels, reduction=16, pool_types=['avg', 'max'], no_spatial=False): super(cbam_module, self).__init__() self.ChannelGate = ChannelGate(gate_channels, reduction, pool_types) self.no_spatial = no_spatial if not no_spatial: self.SpatialGate = SpatialGate() @staticmethod def get_module_name(): return "cbam" def forward(self, x): x_out = 


class se_module(nn.Module):
    """
    module: SE
    input/output: (b, c, h, w)
    paper: Squeeze-and-Excitation Networks (SENet), https://arxiv.org/abs/1709.01507
    code: https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
    """
    def __init__(self, channel, reduction=16):
        super(se_module, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, int(channel // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(channel // reduction), channel, bias=False),
            nn.Sigmoid()
        )

    @staticmethod
    def get_module_name():
        return "se"

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class simam_module(torch.nn.Module):
    """
    module: SimAM
    input/output: (b, c, h, w)
    paper: (ICML 2021) SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks
    code: https://github.com/ZjjConan/SimAM/blob/master/networks/attentions/simam_module.py
    """
    def __init__(self, channels=None, e_lambda=1e-4):
        super(simam_module, self).__init__()
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        # y is the inverse of the minimal neuron energy e_t* from the SimAM paper:
        # 1 / e_t* = (t - mu)^2 / (4 * (sigma^2 + lambda)) + 0.5
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activaton(y)


class ResidualAttention(nn.Module):
    """
    module: ResidualAttention
    input: (b, c, h, w); output: (b, num_class)
    paper: (ICCV 2021) Residual Attention: A Simple but Effective Method for Multi-Label Recognition
    code: https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/attention/ResidualAttention.py
    """
    def __init__(self, channel=512, num_class=1000, la=0.2):
        super().__init__()
        self.la = la
        self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        b, c, h, w = x.shape
        y_raw = self.fc(x).flatten(2)  # b, num_class, h*w
        y_avg = torch.mean(y_raw, dim=2)  # b, num_class
        y_max = torch.max(y_raw, dim=2)[0]  # b, num_class
        score = y_avg + self.la * y_max
        return score
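
# Usage sketch (illustrative only; shapes and hyper-parameters below are assumptions,
# not taken from the source repo):
# >>> feat = torch.randn(2, 64, 32, 32)
# >>> se_module(channel=64)(feat).shape     # re-weighting, shape preserved -> torch.Size([2, 64, 32, 32])
# >>> simam_module()(feat).shape            # parameter-free re-weighting   -> torch.Size([2, 64, 32, 32])
# >>> ResidualAttention(channel=64, num_class=10)(feat).shape  # class scores -> torch.Size([2, 10])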


class eca_module(nn.Module):
    """
    Constructs an ECA module.

    Args:
        channel: number of channels of the input feature map
        k_size: adaptively selected 1-D convolution kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_module, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        b, c, h, w = x.size()

        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        y = self.sigmoid(y)

        return x * y.expand_as(x)


class GAM_Attention(nn.Module):
    """
    module: GAM
    input: (b, in_channels, h, w); output: (b, out_channels, h, w)
    paper: Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
           https://arxiv.org/abs/2112.05561
    code: https://mp.weixin.qq.com/s/VL6rXjyUDmHToYTqM32hUg
    """
    def __init__(self, in_channels, out_channels, rate=4):
        super(GAM_Attention, self).__init__()
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att

        # note: the element-wise product below requires out_channels == in_channels
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
        return out
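

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file): the batch size, channel
    # count, spatial size, and num_class below are assumptions chosen for illustration.
    x = torch.randn(2, 64, 32, 32)
    for module in (simam_module(), cbam_module(gate_channels=64), se_module(channel=64),
                   eca_module(channel=64), GAM_Attention(in_channels=64, out_channels=64)):
        # Each re-weighting module should preserve the input shape: (2, 64, 32, 32).
        print(module.__class__.__name__, module(x).shape)
    # ResidualAttention maps features to per-class scores instead of re-weighting them.
    print(ResidualAttention(channel=64, num_class=10)(x).shape)  # expected: torch.Size([2, 10])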