"""
Attention modules used in the ablation study: ['SimAM', 'CBAM', 'SE', 'GAM'].

ver: Dec 24th 15:00

ref:
https://github.com/xmu-xiaoma666/External-Attention-pytorch
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class BasicConv(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional ReLU, used by SpatialGate."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    """Flattens pooled (b, c, 1, 1) features to (b, c) for the channel MLP."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """CBAM channel attention: pooled descriptors -> shared MLP -> sigmoid gate."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, int(gate_channels // reduction_ratio)),
            nn.ReLU(),
            nn.Linear(int(gate_channels // reduction_ratio), gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pooling: a smooth approximation of global max pooling
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


def logsumexp_2d(tensor):
    # Numerically stable log-sum-exp over the spatial dimensions
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    """Stacks the channel-wise max and mean maps into a 2-channel descriptor."""

    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    """CBAM spatial attention: channel pool -> 7x7 conv -> sigmoid gate."""

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=int((kernel_size - 1) // 2), relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)
        return x * scale


class cbam_module(nn.Module):
    """
    module: CBAM

    input/output: (b, c, h, w)

    paper:
    https://arxiv.org/abs/1807.06521
    code:
    https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
    """

    def __init__(self, gate_channels, reduction=16, pool_types=['avg', 'max'], no_spatial=False):
        super(cbam_module, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    @staticmethod
    def get_module_name():
        return "cbam"

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
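
# Usage sketch (illustrative values, not taken from the original experiments):
# the gating modules in this file preserve the input shape, so they can be
# dropped in after any conv block whose channel count matches the module's
# channel argument, e.g.
#
#     att = cbam_module(gate_channels=64)
#     out = att(torch.randn(2, 64, 32, 32))   # -> (2, 64, 32, 32)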


class se_module(nn.Module):
    """
    module: SE

    input/output: (b, c, h, w)

    from the paper Squeeze-and-Excitation Networks
    SE-Net https://arxiv.org/abs/1709.01507
    code:
    https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
    """

    def __init__(self, channel, reduction=16):
        super(se_module, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, int(channel // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(channel // reduction), channel, bias=False),
            nn.Sigmoid()
        )

    @staticmethod
    def get_module_name():
        return "se"

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class simam_module(torch.nn.Module):
    """
    module: SimAM

    input/output: (b, c, h, w)

    paper (ICML 2021):
    SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks
    code:
    https://github.com/ZjjConan/SimAM/blob/master/networks/attentions/simam_module.py
    """

    def __init__(self, channels=None, e_lambda=1e-4):
        super(simam_module, self).__init__()

        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()

        # number of neurons per channel, excluding the target neuron
        n = w * h - 1

        # closed-form energy: neurons that deviate more from the channel mean receive larger weights
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activation(y)


class ResidualAttention(nn.Module):
    """
    module: ResidualAttention

    input: (b, c, h, w); output: (b, num_class) class scores

    paper: ICCV 2021, Residual Attention: A Simple but Effective Method for Multi-Label Recognition
    code: https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/attention/ResidualAttention.py
    """

    def __init__(self, channel=512, num_class=1000, la=0.2):
        super().__init__()
        self.la = la
        self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        b, c, h, w = x.shape
        y_raw = self.fc(x).flatten(2)        # (b, num_class, h * w)
        y_avg = torch.mean(y_raw, dim=2)     # average over spatial positions
        y_max = torch.max(y_raw, dim=2)[0]   # max over spatial positions
        score = y_avg + self.la * y_max
        return score
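
# Note: unlike the gating modules above, ResidualAttention is a multi-label
# classification head. An illustrative sketch (values are examples only):
#
#     head = ResidualAttention(channel=512, num_class=1000, la=0.2)
#     logits = head(torch.randn(2, 512, 7, 7))   # -> (2, 1000)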


class eca_module(nn.Module):
    """Constructs an ECA module.

    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, channel, k_size=3):
        super(eca_module, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()

        # global spatial descriptor: (b, c, 1, 1)
        y = self.avg_pool(x)

        # local cross-channel interaction via a 1D convolution over the channel dimension
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        y = self.sigmoid(y)

        return x * y.expand_as(x)


class GAM_Attention(nn.Module):
    """
    module: GAM

    input:  (b, in_channels, h, w)
    output: (b, out_channels, h, w)

    paper:
    Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
    https://arxiv.org/abs/2112.05561
    code:
    https://mp.weixin.qq.com/s/VL6rXjyUDmHToYTqM32hUg
    """

    def __init__(self, in_channels, out_channels, rate=4):
        super(GAM_Attention, self).__init__()

        # channel attention: an MLP applied at every spatial position over the channel dimension
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )

        # spatial attention: two 7x7 convolutions with a channel bottleneck
        # (out_channels must equal in_channels, or broadcast against it, for the final elementwise product)
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)

        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out
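

# Minimal shape check for the four ablation modules listed in the module
# docstring (SimAM, CBAM, SE, GAM). The 64-channel 32x32 feature map below is
# an arbitrary example, not a value taken from the original experiments; each
# module should preserve the (b, c, h, w) shape.
if __name__ == "__main__":
    feat = torch.randn(2, 64, 32, 32)

    modules = [
        simam_module(channels=64),
        cbam_module(gate_channels=64),
        se_module(channel=64),
        GAM_Attention(in_channels=64, out_channels=64),
    ]

    for m in modules:
        out = m(feat)
        print(m.__class__.__name__, tuple(out.shape))  # expected: (2, 64, 32, 32)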