Image Segmentation
Transformers
PyTorch
upernet
Inference Endpoints
mccaly's picture
Upload 660 files
b13b124
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
try:
from mmcv.ops import PSAMask
except ModuleNotFoundError:
PSAMask = None
@HEADS.register_module()
class PSAHead(BaseDecodeHead):
"""Point-wise Spatial Attention Network for Scene Parsing.
This head is the implementation of `PSANet
<https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.
Args:
mask_size (tuple[int]): The PSA mask size. It usually equals input
size.
psa_type (str): The type of psa module. Options are 'collect',
'distribute', 'bi-direction'. Default: 'bi-direction'
compact (bool): Whether use compact map for 'collect' mode.
Default: True.
shrink_factor (int): The downsample factors of psa mask. Default: 2.
normalization_factor (float): The normalize factor of attention.
psa_softmax (bool): Whether use softmax for attention.
"""
def __init__(self,
mask_size,
psa_type='bi-direction',
compact=False,
shrink_factor=2,
normalization_factor=1.0,
psa_softmax=True,
**kwargs):
if PSAMask is None:
raise RuntimeError('Please install mmcv-full for PSAMask ops')
super(PSAHead, self).__init__(**kwargs)
assert psa_type in ['collect', 'distribute', 'bi-direction']
self.psa_type = psa_type
self.compact = compact
self.shrink_factor = shrink_factor
self.mask_size = mask_size
mask_h, mask_w = mask_size
self.psa_softmax = psa_softmax
if normalization_factor is None:
normalization_factor = mask_h * mask_w
self.normalization_factor = normalization_factor
self.reduce = ConvModule(
self.in_channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.attention = nn.Sequential(
ConvModule(
self.channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
if psa_type == 'bi-direction':
self.reduce_p = ConvModule(
self.in_channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.attention_p = nn.Sequential(
ConvModule(
self.channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
self.psamask_collect = PSAMask('collect', mask_size)
self.psamask_distribute = PSAMask('distribute', mask_size)
else:
self.psamask = PSAMask(psa_type, mask_size)
self.proj = ConvModule(
self.channels * (2 if psa_type == 'bi-direction' else 1),
self.in_channels,
kernel_size=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.bottleneck = ConvModule(
self.in_channels * 2,
self.channels,
kernel_size=3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
identity = x
align_corners = self.align_corners
if self.psa_type in ['collect', 'distribute']:
out = self.reduce(x)
n, c, h, w = out.size()
if self.shrink_factor != 1:
if h % self.shrink_factor and w % self.shrink_factor:
h = (h - 1) // self.shrink_factor + 1
w = (w - 1) // self.shrink_factor + 1
align_corners = True
else:
h = h // self.shrink_factor
w = w // self.shrink_factor
align_corners = False
out = resize(
out,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
y = self.attention(out)
if self.compact:
if self.psa_type == 'collect':
y = y.view(n, h * w,
h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y = self.psamask(y)
if self.psa_softmax:
y = F.softmax(y, dim=1)
out = torch.bmm(
out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
else:
x_col = self.reduce(x)
x_dis = self.reduce_p(x)
n, c, h, w = x_col.size()
if self.shrink_factor != 1:
if h % self.shrink_factor and w % self.shrink_factor:
h = (h - 1) // self.shrink_factor + 1
w = (w - 1) // self.shrink_factor + 1
align_corners = True
else:
h = h // self.shrink_factor
w = w // self.shrink_factor
align_corners = False
x_col = resize(
x_col,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
x_dis = resize(
x_dis,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
y_col = self.attention(x_col)
y_dis = self.attention_p(x_dis)
if self.compact:
y_dis = y_dis.view(n, h * w,
h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y_col = self.psamask_collect(y_col)
y_dis = self.psamask_distribute(y_dis)
if self.psa_softmax:
y_col = F.softmax(y_col, dim=1)
y_dis = F.softmax(y_dis, dim=1)
x_col = torch.bmm(
x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
x_dis = torch.bmm(
x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
out = torch.cat([x_col, x_dis], 1)
out = self.proj(out)
out = resize(
out,
size=identity.shape[2:],
mode='bilinear',
align_corners=align_corners)
out = self.bottleneck(torch.cat((identity, out), dim=1))
out = self.cls_seg(out)
return out