import torch
from mmcv.cnn import ConvModule, constant_init
from torch import nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            the key and query projections.
        query_downsample (nn.Module): Query downsample module.
        key_downsample (nn.Module): Key downsample module.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value/out projection.
        key_query_norm (bool): Whether to use ConvModule (with norm and
            activation) for the key/query projection.
        value_out_norm (bool): Whether to use ConvModule (with norm and
            activation) for the value/out projection.
        matmul_norm (bool): Whether to normalize the attention map with
            sqrt of channels.
        with_out (bool): Whether to use the out projection.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
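        # 1x1 conv projection that maps the key features to `channels` dims;
        # shared with the query projection when `share_key_query` is True.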
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
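        # Value projection; it outputs `channels` dims when a separate out
        # projection follows, otherwise it maps directly to `out_channels`.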
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
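        # Optional out projection applied to the aggregated context.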
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize the weights of the out projection layer."""
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        """Forward function."""
        batch_size = query_feats.size(0)
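        # Project the query features and flatten them to (B, H*W, C).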
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()
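        # Project the key features to key (B, C, H*W) and value (B, H*W, C);
        # note that the key downsample module is applied to both.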
        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()
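        # Attention map: similarity between every query and key position,
        # optionally scaled by 1/sqrt(channels) and normalized by softmax.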
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)
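        # Aggregate values with the attention weights and restore the
        # spatial layout of the query features.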
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
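

# A minimal usage sketch (not part of the original module): the argument
# values below are illustrative assumptions, not defaults from mmsegmentation.
if __name__ == '__main__':
    self_attn = SelfAttentionBlock(
        key_in_channels=64,
        query_in_channels=64,
        channels=32,
        out_channels=64,
        share_key_query=False,
        query_downsample=None,
        key_downsample=None,
        key_query_num_convs=2,
        value_out_num_convs=1,
        key_query_norm=True,
        value_out_norm=True,
        matmul_norm=True,
        with_out=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'))
    feats = torch.rand(2, 64, 16, 16)
    # Plain self-attention: query and key features are the same tensor here.
    out = self_attn(feats, feats)
    print(out.shape)  # expected: torch.Size([2, 64, 16, 16])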