# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for DeepLabV3."""
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
class SpatialPyramidPooling(tf_keras.layers.Layer):
"""Implements the Atrous Spatial Pyramid Pooling.
References:
[Rethinking Atrous Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1706.05587.pdf)
[Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
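
  A minimal usage sketch (the input shape is illustrative, not required):

  >>> aspp = SpatialPyramidPooling(
  ...     output_channels=256, dilation_rates=[6, 12, 18])
  >>> features = tf.random.normal([2, 32, 32, 1024])
  >>> aspp(features).shape
  TensorShape([2, 32, 32, 256])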
"""

  def __init__(
      self,
      output_channels,
      dilation_rates,
      pool_kernel_size=None,
      use_sync_bn=False,
      batchnorm_momentum=0.99,
      batchnorm_epsilon=0.001,
      activation='relu',
      dropout=0.5,
      kernel_initializer='glorot_uniform',
      kernel_regularizer=None,
      interpolation='bilinear',
      use_depthwise_convolution=False,
      **kwargs):
"""Initializes `SpatialPyramidPooling`.
Args:
output_channels: Number of channels produced by SpatialPyramidPooling.
dilation_rates: A list of integers for parallel dilated conv.
pool_kernel_size: A list of integers or None. If None, global average
pooling is applied, otherwise an average pooling of pool_kernel_size
is applied.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
activation: A `str` for type of activation to be used. Defaults to 'relu'.
dropout: A float for the dropout rate before output. Defaults to 0.5.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
use_depthwise_convolution: Allows spatial pooling to be separable
depthwise convolusions. [Encoder-Decoder with Atrous Separable
Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1802.02611.pdf)
**kwargs: Other keyword arguments for the layer.
"""
    super().__init__(**kwargs)
    self.output_channels = output_channels
    self.dilation_rates = dilation_rates
    self.use_sync_bn = use_sync_bn
    self.batchnorm_momentum = batchnorm_momentum
    self.batchnorm_epsilon = batchnorm_epsilon
    self.activation = activation
    self.dropout = dropout
    self.kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self.kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self.interpolation = interpolation
    self.input_spec = tf_keras.layers.InputSpec(ndim=4)
    self.pool_kernel_size = pool_kernel_size
    self.use_depthwise_convolution = use_depthwise_convolution

  def build(self, input_shape):
    channels = input_shape[3]
    self.aspp_layers = []

    bn_op = tf_keras.layers.BatchNormalization
    if tf_keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1
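
    # Branch 1: a 1x1 convolution (with BN and activation) over the input
    # feature map.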
    conv_sequential = tf_keras.Sequential([
        tf_keras.layers.Conv2D(
            filters=self.output_channels,
            kernel_size=(1, 1),
            kernel_initializer=tf_utils.clone_initializer(
                self.kernel_initializer),
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon,
            synchronized=self.use_sync_bn),
        tf_keras.layers.Activation(self.activation)
    ])
    self.aspp_layers.append(conv_sequential)
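
    # Branches 2..N: one 3x3 atrous convolution per dilation rate, optionally
    # factored into a depthwise 3x3 followed by a pointwise 1x1.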
    for dilation_rate in self.dilation_rates:
      leading_layers = []
      kernel_size = (3, 3)
      if self.use_depthwise_convolution:
        leading_layers += [
            tf_keras.layers.DepthwiseConv2D(
                depth_multiplier=1,
                kernel_size=kernel_size,
                padding='same',
                dilation_rate=dilation_rate,
                use_bias=False)
        ]
        kernel_size = (1, 1)
      conv_sequential = tf_keras.Sequential(leading_layers + [
          tf_keras.layers.Conv2D(
              filters=self.output_channels,
              kernel_size=kernel_size,
              padding='same',
              kernel_regularizer=self.kernel_regularizer,
              kernel_initializer=tf_utils.clone_initializer(
                  self.kernel_initializer),
              dilation_rate=dilation_rate,
              use_bias=False),
          bn_op(
              axis=bn_axis,
              momentum=self.batchnorm_momentum,
              epsilon=self.batchnorm_epsilon,
              synchronized=self.use_sync_bn),
          tf_keras.layers.Activation(self.activation)
      ])
      self.aspp_layers.append(conv_sequential)
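
    # Final branch: image-level features from global (or fixed-size) average
    # pooling, followed by a 1x1 convolution.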
    if self.pool_kernel_size is None:
      pool_sequential = tf_keras.Sequential([
          tf_keras.layers.GlobalAveragePooling2D(),
          tf_keras.layers.Reshape((1, 1, channels))])
    else:
      pool_sequential = tf_keras.Sequential([
          tf_keras.layers.AveragePooling2D(self.pool_kernel_size)])
    pool_sequential.add(
        tf_keras.Sequential([
            tf_keras.layers.Conv2D(
                filters=self.output_channels,
                kernel_size=(1, 1),
                kernel_initializer=tf_utils.clone_initializer(
                    self.kernel_initializer),
                kernel_regularizer=self.kernel_regularizer,
                use_bias=False),
            bn_op(
                axis=bn_axis,
                momentum=self.batchnorm_momentum,
                epsilon=self.batchnorm_epsilon,
                synchronized=self.use_sync_bn),
            tf_keras.layers.Activation(self.activation)
        ]))
    self.aspp_layers.append(pool_sequential)
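
    # 1x1 projection applied to the channel-wise concatenation of all branch
    # outputs, followed by dropout.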
    self.projection = tf_keras.Sequential([
        tf_keras.layers.Conv2D(
            filters=self.output_channels,
            kernel_size=(1, 1),
            kernel_initializer=tf_utils.clone_initializer(
                self.kernel_initializer),
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon,
            synchronized=self.use_sync_bn),
        tf_keras.layers.Activation(self.activation),
        tf_keras.layers.Dropout(rate=self.dropout)
    ])

  def call(self, inputs, training=None):
    if training is None:
      training = tf_keras.backend.learning_phase()
    result = []
    for i, layer in enumerate(self.aspp_layers):
      x = layer(inputs, training=training)
      # The last branch (image-level pooling) reduces the spatial dimensions,
      # so resize its output back to the input's spatial shape. The resize is
      # done in float32 (e.g., under mixed precision) and cast back to the
      # input dtype.
      if i == len(self.aspp_layers) - 1:
        x = tf.image.resize(
            tf.cast(x, tf.float32),
            tf.shape(inputs)[1:3],
            method=self.interpolation)
      result.append(tf.cast(x, inputs.dtype))
    result = tf.concat(result, axis=-1)
    result = self.projection(result, training=training)
    return result

  def get_config(self):
    config = {
        'output_channels': self.output_channels,
        'dilation_rates': self.dilation_rates,
        'pool_kernel_size': self.pool_kernel_size,
        'use_sync_bn': self.use_sync_bn,
        'batchnorm_momentum': self.batchnorm_momentum,
        'batchnorm_epsilon': self.batchnorm_epsilon,
        'activation': self.activation,
        'dropout': self.dropout,
        'kernel_initializer': tf_keras.initializers.serialize(
            self.kernel_initializer),
        'kernel_regularizer': tf_keras.regularizers.serialize(
            self.kernel_regularizer),
        'interpolation': self.interpolation,
        'use_depthwise_convolution': self.use_depthwise_convolution,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))