Spaces:
Configuration error
Configuration error
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import paddle | |
import paddle.nn as nn | |
import paddle.nn.functional as F | |
from paddleseg.models import layers | |
class AttentionBlock(nn.Layer): | |
"""General self-attention block/non-local block. | |
The original article refers to refer to https://arxiv.org/abs/1706.03762. | |
Args: | |
key_in_channels (int): Input channels of key feature. | |
query_in_channels (int): Input channels of query feature. | |
channels (int): Output channels of key/query transform. | |
out_channels (int): Output channels. | |
share_key_query (bool): Whether share projection weight between key | |
and query projection. | |
query_downsample (nn.Module): Query downsample module. | |
key_downsample (nn.Module): Key downsample module. | |
key_query_num_convs (int): Number of convs for key/query projection. | |
value_out_num_convs (int): Number of convs for value projection. | |
key_query_norm (bool): Whether to use BN for key/query projection. | |
value_out_norm (bool): Whether to use BN for value projection. | |
matmul_norm (bool): Whether normalize attention map with sqrt of | |
channels | |
with_out (bool): Whether use out projection. | |
""" | |
def __init__(self, key_in_channels, query_in_channels, channels, | |
out_channels, share_key_query, query_downsample, | |
key_downsample, key_query_num_convs, value_out_num_convs, | |
key_query_norm, value_out_norm, matmul_norm, with_out): | |
super(AttentionBlock, self).__init__() | |
if share_key_query: | |
assert key_in_channels == query_in_channels | |
self.with_out = with_out | |
self.key_in_channels = key_in_channels | |
self.query_in_channels = query_in_channels | |
self.out_channels = out_channels | |
self.channels = channels | |
self.share_key_query = share_key_query | |
self.key_project = self.build_project( | |
key_in_channels, | |
channels, | |
num_convs=key_query_num_convs, | |
use_conv_module=key_query_norm) | |
if share_key_query: | |
self.query_project = self.key_project | |
else: | |
self.query_project = self.build_project( | |
query_in_channels, | |
channels, | |
num_convs=key_query_num_convs, | |
use_conv_module=key_query_norm) | |
self.value_project = self.build_project( | |
key_in_channels, | |
channels if self.with_out else out_channels, | |
num_convs=value_out_num_convs, | |
use_conv_module=value_out_norm) | |
if self.with_out: | |
self.out_project = self.build_project( | |
channels, | |
out_channels, | |
num_convs=value_out_num_convs, | |
use_conv_module=value_out_norm) | |
else: | |
self.out_project = None | |
self.query_downsample = query_downsample | |
self.key_downsample = key_downsample | |
self.matmul_norm = matmul_norm | |
def build_project(self, in_channels, channels, num_convs, use_conv_module): | |
if use_conv_module: | |
convs = [ | |
layers.ConvBNReLU( | |
in_channels=in_channels, | |
out_channels=channels, | |
kernel_size=1, | |
bias_attr=False) | |
] | |
for _ in range(num_convs - 1): | |
convs.append( | |
layers.ConvBNReLU( | |
in_channels=channels, | |
out_channels=channels, | |
kernel_size=1, | |
bias_attr=False)) | |
else: | |
convs = [nn.Conv2D(in_channels, channels, 1)] | |
for _ in range(num_convs - 1): | |
convs.append(nn.Conv2D(channels, channels, 1)) | |
if len(convs) > 1: | |
convs = nn.Sequential(*convs) | |
else: | |
convs = convs[0] | |
return convs | |
def forward(self, query_feats, key_feats): | |
query_shape = paddle.shape(query_feats) | |
query = self.query_project(query_feats) | |
if self.query_downsample is not None: | |
query = self.query_downsample(query) | |
query = query.flatten(2).transpose([0, 2, 1]) | |
key = self.key_project(key_feats) | |
value = self.value_project(key_feats) | |
if self.key_downsample is not None: | |
key = self.key_downsample(key) | |
value = self.key_downsample(value) | |
key = key.flatten(2) | |
value = value.flatten(2).transpose([0, 2, 1]) | |
sim_map = paddle.matmul(query, key) | |
if self.matmul_norm: | |
sim_map = (self.channels**-0.5) * sim_map | |
sim_map = F.softmax(sim_map, axis=-1) | |
context = paddle.matmul(sim_map, value) | |
context = paddle.transpose(context, [0, 2, 1]) | |
context = paddle.reshape( | |
context, [0, self.out_channels, query_shape[2], query_shape[3]]) | |
if self.out_project is not None: | |
context = self.out_project(context) | |
return context | |
class DualAttentionModule(nn.Layer): | |
""" | |
Dual attention module. | |
Args: | |
in_channels (int): The number of input channels. | |
out_channels (int): The number of output channels. | |
""" | |
def __init__(self, in_channels, out_channels): | |
super().__init__() | |
inter_channels = in_channels // 4 | |
self.channel_conv = layers.ConvBNReLU(in_channels, inter_channels, 1) | |
self.position_conv = layers.ConvBNReLU(in_channels, inter_channels, 1) | |
self.pam = PAM(inter_channels) | |
self.cam = CAM(inter_channels) | |
self.conv1 = layers.ConvBNReLU(inter_channels, inter_channels, 3) | |
self.conv2 = layers.ConvBNReLU(inter_channels, inter_channels, 3) | |
self.conv3 = layers.ConvBNReLU(inter_channels, out_channels, 3) | |
def forward(self, feats): | |
channel_feats = self.channel_conv(feats) | |
channel_feats = self.cam(channel_feats) | |
channel_feats = self.conv1(channel_feats) | |
position_feats = self.position_conv(feats) | |
position_feats = self.pam(position_feats) | |
position_feats = self.conv2(position_feats) | |
feats_sum = position_feats + channel_feats | |
out = self.conv3(feats_sum) | |
return out | |
class PAM(nn.Layer): | |
""" | |
Position attention module. | |
Args: | |
in_channels (int): The number of input channels. | |
""" | |
def __init__(self, in_channels): | |
super().__init__() | |
mid_channels = in_channels // 8 | |
self.mid_channels = mid_channels | |
self.in_channels = in_channels | |
self.query_conv = nn.Conv2D(in_channels, mid_channels, 1, 1) | |
self.key_conv = nn.Conv2D(in_channels, mid_channels, 1, 1) | |
self.value_conv = nn.Conv2D(in_channels, in_channels, 1, 1) | |
self.gamma = self.create_parameter( | |
shape=[1], | |
dtype='float32', | |
default_initializer=nn.initializer.Constant(0)) | |
def forward(self, x): | |
x_shape = paddle.shape(x) | |
# query: n, h * w, c1 | |
query = self.query_conv(x) | |
query = paddle.reshape(query, (0, self.mid_channels, -1)) | |
query = paddle.transpose(query, (0, 2, 1)) | |
# key: n, c1, h * w | |
key = self.key_conv(x) | |
key = paddle.reshape(key, (0, self.mid_channels, -1)) | |
# sim: n, h * w, h * w | |
sim = paddle.bmm(query, key) | |
sim = F.softmax(sim, axis=-1) | |
value = self.value_conv(x) | |
value = paddle.reshape(value, (0, self.in_channels, -1)) | |
sim = paddle.transpose(sim, (0, 2, 1)) | |
# feat: from (n, c2, h * w) -> (n, c2, h, w) | |
feat = paddle.bmm(value, sim) | |
feat = paddle.reshape(feat, | |
(0, self.in_channels, x_shape[2], x_shape[3])) | |
out = self.gamma * feat + x | |
return out | |
class CAM(nn.Layer): | |
""" | |
Channel attention module. | |
Args: | |
in_channels (int): The number of input channels. | |
""" | |
def __init__(self, channels): | |
super().__init__() | |
self.channels = channels | |
self.gamma = self.create_parameter( | |
shape=[1], | |
dtype='float32', | |
default_initializer=nn.initializer.Constant(0)) | |
def forward(self, x): | |
x_shape = paddle.shape(x) | |
# query: n, c, h * w | |
query = paddle.reshape(x, (0, self.channels, -1)) | |
# key: n, h * w, c | |
key = paddle.reshape(x, (0, self.channels, -1)) | |
key = paddle.transpose(key, (0, 2, 1)) | |
# sim: n, c, c | |
sim = paddle.bmm(query, key) | |
# The danet author claims that this can avoid gradient divergence | |
sim = paddle.max(sim, axis=-1, keepdim=True).tile( | |
[1, 1, self.channels]) - sim | |
sim = F.softmax(sim, axis=-1) | |
# feat: from (n, c, h * w) to (n, c, h, w) | |
value = paddle.reshape(x, (0, self.channels, -1)) | |
feat = paddle.bmm(sim, value) | |
feat = paddle.reshape(feat, (0, self.channels, x_shape[2], x_shape[3])) | |
out = self.gamma * feat + x | |
return out | |