# PuzzleTuning_VPT/PuzzleTuning/Backbone/attention_modules.py
"""
attention modules in ['SimAM', 'CBAM', 'SE', 'GAM'] were applied in the ablation study
ver: Dec 24th 15:00
ref:
https://github.com/xmu-xiaoma666/External-Attention-pytorch
"""
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.nn import init
# helper modules and functions
class BasicConv(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
bn=True, bias=False):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class ChannelGate(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
super(ChannelGate, self).__init__()
self.gate_channels = gate_channels
self.mlp = nn.Sequential(
Flatten(),
nn.Linear(gate_channels, int(gate_channels // reduction_ratio)),
nn.ReLU(),
nn.Linear(int(gate_channels // reduction_ratio), gate_channels)
)
self.pool_types = pool_types
def forward(self, x):
channel_att_sum = None
for pool_type in self.pool_types:
if pool_type == 'avg':
avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp(avg_pool)
elif pool_type == 'max':
max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp(max_pool)
elif pool_type == 'lp':
lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp(lp_pool)
elif pool_type == 'lse':
# LSE pool only
lse_pool = logsumexp_2d(x)
channel_att_raw = self.mlp(lse_pool)
if channel_att_sum is None:
channel_att_sum = channel_att_raw
else:
channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
return x * scale
def logsumexp_2d(tensor):
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
return outputs
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=int((kernel_size - 1) // 2), relu=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting over channels
return x * scale
# attention modules:
class cbam_module(nn.Module):
"""
module:CBAM
input、output= b, c, h, w
paper:
https://arxiv.org/abs/1807.06521
code:
https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
"""
def __init__(self, gate_channels, reduction=16, pool_types=['avg', 'max'], no_spatial=False):
super(cbam_module, self).__init__()
self.ChannelGate = ChannelGate(gate_channels, reduction, pool_types)
self.no_spatial = no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
@staticmethod
def get_module_name():
return "cbam"
def forward(self, x):
x_out = self.ChannelGate(x)
if not self.no_spatial:
x_out = self.SpatialGate(x_out)
return x_out
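
# A minimal usage sketch for cbam_module; the batch size, channel count and spatial size below
# are illustrative assumptions, not values taken from PuzzleTuning:
#   cbam = cbam_module(gate_channels=64)
#   y = cbam(torch.randn(2, 64, 32, 32))  # y.shape == (2, 64, 32, 32), same as the input
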
class se_module(nn.Module):
"""
module: SE
input、output= b, c, h, w
from paper Squeeze-and-Excitation Networks
SE-Net https://arxiv.org/abs/1709.01507
code:
https://github.com/ZjjConan/SimAM/blob/master/networks/attentions
"""
def __init__(self, channel, reduction=16):
super(se_module, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, int(channel // reduction), bias=False),
nn.ReLU(inplace=True),
nn.Linear(int(channel // reduction), channel, bias=False),
nn.Sigmoid()
)
@staticmethod
def get_module_name():
return "se"
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
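
# A minimal usage sketch for se_module; `channel` must match the channel dimension of the input
# feature map (the shapes below are illustrative assumptions):
#   se = se_module(channel=64, reduction=16)
#   y = se(torch.randn(2, 64, 32, 32))  # y.shape == (2, 64, 32, 32)
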
class simam_module(torch.nn.Module):
"""
module:SimAM
input、output= b, c, h, w
paper:(ICML)
SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks
code:
https://github.com/ZjjConan/SimAM/blob/master/networks/attentions/simam_module.py
"""
def __init__(self, channels=None, e_lambda=1e-4):
super(simam_module, self).__init__()
        self.activation = nn.Sigmoid()
self.e_lambda = e_lambda
def __repr__(self):
s = self.__class__.__name__ + '('
s += ('lambda=%f)' % self.e_lambda)
return s
@staticmethod
def get_module_name():
return "simam"
def forward(self, x):
b, c, h, w = x.size()
n = w * h - 1
x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activation(y)
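
# A minimal usage sketch for simam_module; SimAM is parameter-free, so `channels` is accepted only
# to keep a uniform constructor interface (the shapes below are illustrative assumptions):
#   simam = simam_module(e_lambda=1e-4)
#   y = simam(torch.randn(2, 64, 32, 32))  # y.shape == (2, 64, 32, 32)
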
class ResidualAttention(nn.Module):
"""
module: ResidualAttention
input、output= b, c, h, w
Paper:ICCV 2021 Residual Attention: A Simple but Effective Method for Multi-Label Recognition
code:https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/attention/ResidualAttention.py
"""
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)
def forward(self, x):
b, c, h, w = x.shape
y_raw = self.fc(x).flatten(2) # b,num_class,hxw
y_avg = torch.mean(y_raw, dim=2) # b,num_class
y_max = torch.max(y_raw, dim=2)[0] # b,num_class
score = y_avg + self.la * y_max
return score
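
# A minimal usage sketch for ResidualAttention; unlike the other modules it acts as a classification
# head and returns per-class scores (the channel/class counts below are illustrative assumptions):
#   ra = ResidualAttention(channel=64, num_class=10, la=0.2)
#   scores = ra(torch.randn(2, 64, 32, 32))  # scores.shape == (2, 10)
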
class eca_module(nn.Module):
"""Constructs a ECA module.
Args:
channel: Number of channels of the input feature map
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel, k_size=3):
super(eca_module, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# x: input features with shape [b, c, h, w]
b, c, h, w = x.size()
# feature descriptor on the global spatial information
y = self.avg_pool(x)
# Two different branches of ECA module
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# Multi-scale information fusion
y = self.sigmoid(y)
return x * y.expand_as(x)
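
# A minimal usage sketch for eca_module; `k_size` should be odd so the 1-D convolution preserves
# the channel length (the shapes below are illustrative assumptions):
#   eca = eca_module(channel=64, k_size=3)
#   y = eca(torch.randn(2, 64, 32, 32))  # y.shape == (2, 64, 32, 32)
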
class GAM_Attention(nn.Module):
"""
module:GAM
input= b, in_channels, h, w
output= b, out_channels, h, w
paper:
Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
https://arxiv.org/abs/2112.05561
code:
https://mp.weixin.qq.com/s/VL6rXjyUDmHToYTqM32hUg
"""
def __init__(self, in_channels, out_channels, rate=4):
super(GAM_Attention, self).__init__()
self.channel_attention = nn.Sequential(
nn.Linear(in_channels, int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Linear(int(in_channels / rate), in_channels)
)
self.spatial_attention = nn.Sequential(
nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
nn.BatchNorm2d(int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
b, c, h, w = x.shape
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
x_channel_att = x_att_permute.permute(0, 3, 1, 2)
x = x * x_channel_att
x_spatial_att = self.spatial_attention(x).sigmoid()
out = x * x_spatial_att
return out
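

# Minimal smoke test covering every module above, intended for quick shape checking only.
# The batch size, channel count and spatial size are illustrative assumptions, not values
# used elsewhere in PuzzleTuning.
if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)

    feature_map_modules = [
        cbam_module(gate_channels=64),
        se_module(channel=64),
        simam_module(channels=64),
        eca_module(channel=64),
        GAM_Attention(in_channels=64, out_channels=64),
    ]
    for module in feature_map_modules:
        y = module(x)
        print(module.__class__.__name__, tuple(y.shape))  # each keeps the (2, 64, 32, 32) shape

    # ResidualAttention is a classification head: it returns (batch, num_class) scores
    ra = ResidualAttention(channel=64, num_class=10)
    print(ra.__class__.__name__, tuple(ra(x).shape))  # (2, 10)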