# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Layers for DeepLabV3.""" import tensorflow as tf, tf_keras from official.modeling import tf_utils class SpatialPyramidPooling(tf_keras.layers.Layer): """Implements the Atrous Spatial Pyramid Pooling. References: [Rethinking Atrous Convolution for Semantic Image Segmentation]( https://arxiv.org/pdf/1706.05587.pdf) [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf) """ def __init__( self, output_channels, dilation_rates, pool_kernel_size=None, use_sync_bn=False, batchnorm_momentum=0.99, batchnorm_epsilon=0.001, activation='relu', dropout=0.5, kernel_initializer='glorot_uniform', kernel_regularizer=None, interpolation='bilinear', use_depthwise_convolution=False, **kwargs): """Initializes `SpatialPyramidPooling`. Args: output_channels: Number of channels produced by SpatialPyramidPooling. dilation_rates: A list of integers for parallel dilated conv. pool_kernel_size: A list of integers or None. If None, global average pooling is applied, otherwise an average pooling of pool_kernel_size is applied. use_sync_bn: A bool, whether or not to use sync batch normalization. batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to 0.99. batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to 0.001. activation: A `str` for type of activation to be used. Defaults to 'relu'. dropout: A float for the dropout rate before output. Defaults to 0.5. kernel_initializer: Kernel initializer for conv layers. Defaults to `glorot_uniform`. kernel_regularizer: Kernel regularizer for conv layers. Defaults to None. interpolation: The interpolation method for upsampling. Defaults to `bilinear`. use_depthwise_convolution: Allows spatial pooling to be separable depthwise convolusions. [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation]( https://arxiv.org/pdf/1802.02611.pdf) **kwargs: Other keyword arguments for the layer. """ super(SpatialPyramidPooling, self).__init__(**kwargs) self.output_channels = output_channels self.dilation_rates = dilation_rates self.use_sync_bn = use_sync_bn self.batchnorm_momentum = batchnorm_momentum self.batchnorm_epsilon = batchnorm_epsilon self.activation = activation self.dropout = dropout self.kernel_initializer = tf_keras.initializers.get(kernel_initializer) self.kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer) self.interpolation = interpolation self.input_spec = tf_keras.layers.InputSpec(ndim=4) self.pool_kernel_size = pool_kernel_size self.use_depthwise_convolution = use_depthwise_convolution def build(self, input_shape): channels = input_shape[3] self.aspp_layers = [] bn_op = tf_keras.layers.BatchNormalization if tf_keras.backend.image_data_format() == 'channels_last': bn_axis = -1 else: bn_axis = 1 conv_sequential = tf_keras.Sequential([ tf_keras.layers.Conv2D( filters=self.output_channels, kernel_size=(1, 1), kernel_initializer=tf_utils.clone_initializer( self.kernel_initializer), kernel_regularizer=self.kernel_regularizer, use_bias=False), bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, epsilon=self.batchnorm_epsilon, synchronized=self.use_sync_bn), tf_keras.layers.Activation(self.activation) ]) self.aspp_layers.append(conv_sequential) for dilation_rate in self.dilation_rates: leading_layers = [] kernel_size = (3, 3) if self.use_depthwise_convolution: leading_layers += [ tf_keras.layers.DepthwiseConv2D( depth_multiplier=1, kernel_size=kernel_size, padding='same', dilation_rate=dilation_rate, use_bias=False) ] kernel_size = (1, 1) conv_sequential = tf_keras.Sequential(leading_layers + [ tf_keras.layers.Conv2D( filters=self.output_channels, kernel_size=kernel_size, padding='same', kernel_regularizer=self.kernel_regularizer, kernel_initializer=tf_utils.clone_initializer( self.kernel_initializer), dilation_rate=dilation_rate, use_bias=False), bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, epsilon=self.batchnorm_epsilon, synchronized=self.use_sync_bn), tf_keras.layers.Activation(self.activation) ]) self.aspp_layers.append(conv_sequential) if self.pool_kernel_size is None: pool_sequential = tf_keras.Sequential([ tf_keras.layers.GlobalAveragePooling2D(), tf_keras.layers.Reshape((1, 1, channels))]) else: pool_sequential = tf_keras.Sequential([ tf_keras.layers.AveragePooling2D(self.pool_kernel_size)]) pool_sequential.add( tf_keras.Sequential([ tf_keras.layers.Conv2D( filters=self.output_channels, kernel_size=(1, 1), kernel_initializer=tf_utils.clone_initializer( self.kernel_initializer), kernel_regularizer=self.kernel_regularizer, use_bias=False), bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, epsilon=self.batchnorm_epsilon, synchronized=self.use_sync_bn), tf_keras.layers.Activation(self.activation) ])) self.aspp_layers.append(pool_sequential) self.projection = tf_keras.Sequential([ tf_keras.layers.Conv2D( filters=self.output_channels, kernel_size=(1, 1), kernel_initializer=tf_utils.clone_initializer( self.kernel_initializer), kernel_regularizer=self.kernel_regularizer, use_bias=False), bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, epsilon=self.batchnorm_epsilon, synchronized=self.use_sync_bn), tf_keras.layers.Activation(self.activation), tf_keras.layers.Dropout(rate=self.dropout) ]) def call(self, inputs, training=None): if training is None: training = tf_keras.backend.learning_phase() result = [] for i, layer in enumerate(self.aspp_layers): x = layer(inputs, training=training) # Apply resize layer to the end of the last set of layers. if i == len(self.aspp_layers) - 1: x = tf.image.resize(tf.cast(x, tf.float32), tf.shape(inputs)[1:3]) result.append(tf.cast(x, inputs.dtype)) result = tf.concat(result, axis=-1) result = self.projection(result, training=training) return result def get_config(self): config = { 'output_channels': self.output_channels, 'dilation_rates': self.dilation_rates, 'pool_kernel_size': self.pool_kernel_size, 'use_sync_bn': self.use_sync_bn, 'batchnorm_momentum': self.batchnorm_momentum, 'batchnorm_epsilon': self.batchnorm_epsilon, 'activation': self.activation, 'dropout': self.dropout, 'kernel_initializer': tf_keras.initializers.serialize( self.kernel_initializer), 'kernel_regularizer': tf_keras.regularizers.serialize( self.kernel_regularizer), 'interpolation': self.interpolation, } base_config = super(SpatialPyramidPooling, self).get_config() return dict(list(base_config.items()) + list(config.items()))