import torch
from mmcv.cnn import ConvModule, constant_init
from torch import nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
"""General self-attention block/non-local block. |
|
|
|
Please refer to https://arxiv.org/abs/1706.03762 for details about key, |
|
query and value. |
|
|
|
Args: |
|
key_in_channels (int): Input channels of key feature. |
|
query_in_channels (int): Input channels of query feature. |
|
channels (int): Output channels of key/query transform. |
|
out_channels (int): Output channels. |
|
share_key_query (bool): Whether share projection weight between key |
|
and query projection. |
|
query_downsample (nn.Module): Query downsample module. |
|
key_downsample (nn.Module): Key downsample module. |
|
key_query_num_convs (int): Number of convs for key/query projection. |
|
value_num_convs (int): Number of convs for value projection. |
|
matmul_norm (bool): Whether normalize attention map with sqrt of |
|
channels |
|
with_out (bool): Whether use out projection. |
|
conv_cfg (dict|None): Config of conv layers. |
|
norm_cfg (dict|None): Config of norm layers. |
|
act_cfg (dict|None): Config of activation layers. |
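
    Example:
        A minimal usage sketch showing how the block attends from
        ``query_feats`` over ``key_feats``; the channel and spatial sizes
        below are arbitrary illustrations, not required values.

        >>> import torch
        >>> block = SelfAttentionBlock(
        ...     key_in_channels=64,
        ...     query_in_channels=64,
        ...     channels=32,
        ...     out_channels=64,
        ...     share_key_query=False,
        ...     query_downsample=None,
        ...     key_downsample=None,
        ...     key_query_num_convs=2,
        ...     value_out_num_convs=1,
        ...     key_query_norm=False,
        ...     value_out_norm=False,
        ...     matmul_norm=True,
        ...     with_out=True,
        ...     conv_cfg=None,
        ...     norm_cfg=None,
        ...     act_cfg=None)
        >>> query_feats = torch.rand(2, 64, 32, 32)
        >>> key_feats = torch.rand(2, 64, 16, 16)
        >>> block(query_feats, key_feats).shape
        torch.Size([2, 64, 32, 32])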
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # Key/query projections; the same module is reused for both when
        # share_key_query is True.
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        # The value projection maps to `channels` when an out projection
        # follows, otherwise directly to `out_channels`.
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize the weights of the out projection layer."""
        if self.out_project is not None:
            # ConvModule handles its own initialization; only a plain conv
            # out projection is zero-initialized here.
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
        if use_conv_module:
            # 1x1 ConvModules: conv followed by norm and activation.
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            # Plain 1x1 convs without norm or activation.
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        """Forward function."""
        batch_size = query_feats.size(0)
        # Project the query and flatten its spatial dims:
        # (batch, N_query, channels).
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()

        # Project key/value from the key features; the key downsample module
        # (if any) is applied to both.
        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        # key: (batch, channels, N_key); value: (batch, N_key, C_value).
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

        # Attention map over key positions: (batch, N_query, N_key).
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            # Scaled dot-product attention, as in the Transformer paper.
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # Aggregate values and restore the query's spatial layout.
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context