# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers


class AttentionBlock(nn.Layer):
"""General self-attention block/non-local block.
The original article refers to refer to https://arxiv.org/abs/1706.03762.
Args:
key_in_channels (int): Input channels of key feature.
query_in_channels (int): Input channels of query feature.
channels (int): Output channels of key/query transform.
out_channels (int): Output channels.
        share_key_query (bool): Whether to share the projection weight between
            the key and query projections.
query_downsample (nn.Module): Query downsample module.
key_downsample (nn.Module): Key downsample module.
key_query_num_convs (int): Number of convs for key/query projection.
value_out_num_convs (int): Number of convs for value projection.
key_query_norm (bool): Whether to use BN for key/query projection.
value_out_norm (bool): Whether to use BN for value projection.
        matmul_norm (bool): Whether to normalize the attention map by the
            square root of channels.
        with_out (bool): Whether to use the out projection.
    """

def __init__(self, key_in_channels, query_in_channels, channels,
out_channels, share_key_query, query_downsample,
key_downsample, key_query_num_convs, value_out_num_convs,
key_query_norm, value_out_norm, matmul_norm, with_out):
super(AttentionBlock, self).__init__()
if share_key_query:
assert key_in_channels == query_in_channels
self.with_out = with_out
self.key_in_channels = key_in_channels
self.query_in_channels = query_in_channels
self.out_channels = out_channels
self.channels = channels
self.share_key_query = share_key_query
self.key_project = self.build_project(
key_in_channels,
channels,
num_convs=key_query_num_convs,
use_conv_module=key_query_norm)
if share_key_query:
self.query_project = self.key_project
else:
self.query_project = self.build_project(
query_in_channels,
channels,
num_convs=key_query_num_convs,
use_conv_module=key_query_norm)
self.value_project = self.build_project(
key_in_channels,
channels if self.with_out else out_channels,
num_convs=value_out_num_convs,
use_conv_module=value_out_norm)
if self.with_out:
self.out_project = self.build_project(
channels,
out_channels,
num_convs=value_out_num_convs,
use_conv_module=value_out_norm)
else:
self.out_project = None
self.query_downsample = query_downsample
self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

    def build_project(self, in_channels, channels, num_convs, use_conv_module):
if use_conv_module:
convs = [
layers.ConvBNReLU(
in_channels=in_channels,
out_channels=channels,
kernel_size=1,
bias_attr=False)
]
for _ in range(num_convs - 1):
convs.append(
layers.ConvBNReLU(
in_channels=channels,
out_channels=channels,
kernel_size=1,
bias_attr=False))
else:
convs = [nn.Conv2D(in_channels, channels, 1)]
for _ in range(num_convs - 1):
convs.append(nn.Conv2D(channels, channels, 1))
if len(convs) > 1:
convs = nn.Sequential(*convs)
else:
convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
query_shape = paddle.shape(query_feats)
query = self.query_project(query_feats)
if self.query_downsample is not None:
query = self.query_downsample(query)
query = query.flatten(2).transpose([0, 2, 1])
key = self.key_project(key_feats)
value = self.value_project(key_feats)
if self.key_downsample is not None:
key = self.key_downsample(key)
value = self.key_downsample(value)
key = key.flatten(2)
value = value.flatten(2).transpose([0, 2, 1])
sim_map = paddle.matmul(query, key)
if self.matmul_norm:
sim_map = (self.channels**-0.5) * sim_map
sim_map = F.softmax(sim_map, axis=-1)
context = paddle.matmul(sim_map, value)
context = paddle.transpose(context, [0, 2, 1])
context = paddle.reshape(
context, [0, self.out_channels, query_shape[2], query_shape[3]])
if self.out_project is not None:
context = self.out_project(context)
return context
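

# A minimal usage sketch (not part of the original PaddleSeg file): wiring up
# AttentionBlock as plain self-attention, with query and key taken from the
# same feature map. Every shape and argument value below is an assumption made
# for illustration; `channels` equals `out_channels` so that the final reshape
# to `out_channels` matches the width of the value projection when
# `with_out=True`.
def _example_attention_block():
    block = AttentionBlock(
        key_in_channels=64,
        query_in_channels=64,
        channels=64,
        out_channels=64,
        share_key_query=False,
        query_downsample=None,
        key_downsample=None,
        key_query_num_convs=2,
        value_out_num_convs=1,
        key_query_norm=True,
        value_out_norm=True,
        matmul_norm=True,
        with_out=True)
    x = paddle.rand([2, 64, 32, 32])
    # Using the same tensor for query_feats and key_feats gives standard
    # self-attention over the 32 x 32 spatial positions.
    out = block(x, x)
    return out  # expected shape: [2, 64, 32, 32]

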
class DualAttentionModule(nn.Layer):
"""
Dual attention module.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
"""
def __init__(self, in_channels, out_channels):
super().__init__()
inter_channels = in_channels // 4
self.channel_conv = layers.ConvBNReLU(in_channels, inter_channels, 1)
self.position_conv = layers.ConvBNReLU(in_channels, inter_channels, 1)
self.pam = PAM(inter_channels)
self.cam = CAM(inter_channels)
self.conv1 = layers.ConvBNReLU(inter_channels, inter_channels, 3)
self.conv2 = layers.ConvBNReLU(inter_channels, inter_channels, 3)
        self.conv3 = layers.ConvBNReLU(inter_channels, out_channels, 3)

    def forward(self, feats):
channel_feats = self.channel_conv(feats)
channel_feats = self.cam(channel_feats)
channel_feats = self.conv1(channel_feats)
position_feats = self.position_conv(feats)
position_feats = self.pam(position_feats)
position_feats = self.conv2(position_feats)
feats_sum = position_feats + channel_feats
out = self.conv3(feats_sum)
return out
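

# A minimal usage sketch (not part of the original PaddleSeg file): running the
# dual attention head on a backbone-style feature map. The 512 input channels,
# 256 output channels, and 16x16 spatial size are assumptions for illustration.
def _example_dual_attention():
    dam = DualAttentionModule(in_channels=512, out_channels=256)
    feats = paddle.rand([2, 512, 16, 16])
    out = dam(feats)  # position and channel branches are fused by addition
    return out  # expected shape: [2, 256, 16, 16]

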
class PAM(nn.Layer):
"""
Position attention module.
Args:
in_channels (int): The number of input channels.
"""
def __init__(self, in_channels):
super().__init__()
mid_channels = in_channels // 8
self.mid_channels = mid_channels
self.in_channels = in_channels
self.query_conv = nn.Conv2D(in_channels, mid_channels, 1, 1)
self.key_conv = nn.Conv2D(in_channels, mid_channels, 1, 1)
self.value_conv = nn.Conv2D(in_channels, in_channels, 1, 1)
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
            default_initializer=nn.initializer.Constant(0))

    def forward(self, x):
x_shape = paddle.shape(x)
# query: n, h * w, c1
query = self.query_conv(x)
query = paddle.reshape(query, (0, self.mid_channels, -1))
query = paddle.transpose(query, (0, 2, 1))
# key: n, c1, h * w
key = self.key_conv(x)
key = paddle.reshape(key, (0, self.mid_channels, -1))
# sim: n, h * w, h * w
sim = paddle.bmm(query, key)
sim = F.softmax(sim, axis=-1)
value = self.value_conv(x)
value = paddle.reshape(value, (0, self.in_channels, -1))
sim = paddle.transpose(sim, (0, 2, 1))
# feat: from (n, c2, h * w) -> (n, c2, h, w)
feat = paddle.bmm(value, sim)
feat = paddle.reshape(feat,
(0, self.in_channels, x_shape[2], x_shape[3]))
out = self.gamma * feat + x
return out
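

# A minimal usage sketch (not part of the original PaddleSeg file): the position
# attention module preserves both the channel count and the spatial size of its
# input, and `gamma` starts at zero, so an untrained PAM acts as an identity
# mapping. The 64 channels and 16x16 size are assumptions for illustration.
def _example_pam():
    pam = PAM(in_channels=64)
    x = paddle.rand([2, 64, 16, 16])
    out = pam(x)  # attention over the 16 * 16 = 256 spatial positions
    return out  # expected shape: [2, 64, 16, 16]; equals x at initialization

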
class CAM(nn.Layer):
"""
Channel attention module.
Args:
        channels (int): The number of input channels.
    """

    def __init__(self, channels):
super().__init__()
self.channels = channels
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
            default_initializer=nn.initializer.Constant(0))

    def forward(self, x):
x_shape = paddle.shape(x)
# query: n, c, h * w
query = paddle.reshape(x, (0, self.channels, -1))
# key: n, h * w, c
key = paddle.reshape(x, (0, self.channels, -1))
key = paddle.transpose(key, (0, 2, 1))
# sim: n, c, c
sim = paddle.bmm(query, key)
        # The DANet authors claim that this helps avoid gradient divergence.
sim = paddle.max(sim, axis=-1, keepdim=True).tile(
[1, 1, self.channels]) - sim
sim = F.softmax(sim, axis=-1)
# feat: from (n, c, h * w) to (n, c, h, w)
value = paddle.reshape(x, (0, self.channels, -1))
feat = paddle.bmm(sim, value)
feat = paddle.reshape(feat, (0, self.channels, x_shape[2], x_shape[3]))
out = self.gamma * feat + x
return out
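

# A minimal usage sketch (not part of the original PaddleSeg file): channel
# attention builds a C x C similarity matrix, so the output keeps the input
# shape. As with PAM, `gamma` is initialized to zero, so the module starts out
# as an identity mapping. Shapes below are assumptions for illustration.
def _example_cam():
    cam = CAM(channels=64)
    x = paddle.rand([2, 64, 16, 16])
    out = cam(x)  # attention over the 64 channels rather than spatial positions
    return out  # expected shape: [2, 64, 16, 16]; equals x at initialization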