# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of segmentation heads."""
from typing import List, Union, Optional, Mapping, Tuple, Any
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.vision.modeling.layers import nn_layers
from official.vision.ops import spatial_transform_ops


class MaskScoring(tf_keras.Model):
"""Creates a mask scoring layer.
  This implements the mask scoring layer from the paper:
Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, Xinggang Wang.
Mask Scoring R-CNN.
(https://arxiv.org/pdf/1903.00241.pdf)
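
  Example (an illustrative sketch; the shapes and hyperparameter values below
  are assumptions chosen for demonstration, not values from a released
  config):

    head = MaskScoring(num_classes=21, fc_input_size=[4, 4], fc_dims=256)
    seg_logits = tf.ones([2, 64, 64, 21])  # Illustrative segmentation logits.
    mask_scores = head(seg_logits)         # Per-class IoU scores, [2, 21].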
"""

  def __init__(
self,
num_classes: int,
fc_input_size: List[int],
num_convs: int = 3,
num_filters: int = 256,
use_depthwise_convolution: bool = False,
fc_dims: int = 1024,
num_fcs: int = 2,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes mask scoring layer.
Args:
num_classes: An `int` for number of classes.
      fc_input_size: A list of two `int`s, the [height, width] to which the
        features are resized before being flattened for the fully connected
        layers.
      num_convs: An `int` for the number of conv layers.
      num_filters: An `int` for the number of filters for conv layers.
      use_depthwise_convolution: A `bool`, whether or not to use depthwise
        convolutions.
      fc_dims: An `int` number of filters for each fully connected layer.
num_fcs: An `int` for number of fully connected layers.
activation: A `str` name of the activation function.
use_sync_bn: A bool, whether or not to use sync batch normalization.
norm_momentum: A float for the momentum in BatchNorm. Defaults to 0.99.
norm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
Conv2D. Default is None.
bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
**kwargs: Additional keyword arguments to be passed.
"""
super(MaskScoring, self).__init__(**kwargs)
self._config_dict = {
'num_classes': num_classes,
'num_convs': num_convs,
'num_filters': num_filters,
'fc_input_size': fc_input_size,
'fc_dims': fc_dims,
'num_fcs': num_fcs,
'use_sync_bn': use_sync_bn,
'use_depthwise_convolution': use_depthwise_convolution,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'activation': activation,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
}
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the mask scoring head."""
conv_op = tf_keras.layers.Conv2D
conv_kwargs = {
'filters': self._config_dict['num_filters'],
'kernel_size': 3,
'padding': 'same',
}
conv_kwargs.update({
'kernel_initializer': tf_keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
bn_op = tf_keras.layers.BatchNormalization
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
}
self._convs = []
self._conv_norms = []
for i in range(self._config_dict['num_convs']):
if self._config_dict['use_depthwise_convolution']:
self._convs.append(
tf_keras.layers.DepthwiseConv2D(
name='mask-scoring-depthwise-conv-{}'.format(i),
kernel_size=3,
padding='same',
use_bias=False,
depthwise_initializer=tf_keras.initializers.RandomNormal(
stddev=0.01),
depthwise_regularizer=self._config_dict['kernel_regularizer'],
depth_multiplier=1))
norm_name = 'mask-scoring-depthwise-bn-{}'.format(i)
self._conv_norms.append(bn_op(name=norm_name, **bn_kwargs))
conv_name = 'mask-scoring-conv-{}'.format(i)
if 'kernel_initializer' in conv_kwargs:
conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
conv_kwargs['kernel_initializer'])
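      # When depthwise convolution is enabled, the 3x3 spatial filtering is
      # handled by the depthwise conv appended above, so the regular conv is
      # reduced to a 1x1 pointwise projection.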
if self._config_dict['use_depthwise_convolution']:
conv_kwargs['kernel_size'] = 1
self._convs.append(conv_op(name=conv_name, **conv_kwargs))
bn_name = 'mask-scoring-bn-{}'.format(i)
self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))
self._fcs = []
self._fc_norms = []
for i in range(self._config_dict['num_fcs']):
fc_name = 'mask-scoring-fc-{}'.format(i)
self._fcs.append(
tf_keras.layers.Dense(
units=self._config_dict['fc_dims'],
kernel_initializer=tf_keras.initializers.VarianceScaling(
scale=1 / 3.0, mode='fan_out', distribution='uniform'),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
name=fc_name))
bn_name = 'mask-scoring-fc-bn-{}'.format(i)
self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))
self._classifier = tf_keras.layers.Dense(
units=self._config_dict['num_classes'],
kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
name='iou-scores')
super(MaskScoring, self).build(input_shape)

  def call(self, inputs: tf.Tensor, training: bool = None):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
"""Forward pass mask scoring head.
Args:
inputs: A `tf.Tensor` of the shape [batch_size, width, size, num_classes],
representing the segmentation logits.
training: a `bool` indicating whether it is in `training` mode.
Returns:
mask_scores: A `tf.Tensor` of predicted mask scores
[batch_size, num_classes].
"""
x = tf.stop_gradient(inputs)
for conv, bn in zip(self._convs, self._conv_norms):
x = conv(x)
x = bn(x)
x = self._activation(x)
# Casts feat to float32 so the resize op can be run on TPU.
x = tf.cast(x, tf.float32)
x = tf.image.resize(x, size=self._config_dict['fc_input_size'],
method=tf.image.ResizeMethod.BILINEAR)
    # Casts it back to be compatible with the remaining operations.
x = tf.cast(x, inputs.dtype)
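    # Flattens the resized feature map so it can be fed to the fully
    # connected layers.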
_, h, w, filters = x.get_shape().as_list()
x = tf.reshape(x, [-1, h * w * filters])
for fc, bn in zip(self._fcs, self._fc_norms):
x = fc(x)
x = bn(x)
x = self._activation(x)
ious = self._classifier(x)
return ious

  def get_config(self) -> Mapping[str, Any]:
return self._config_dict

  @classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)


@tf_keras.utils.register_keras_serializable(package='Vision')
class SegmentationHead(tf_keras.layers.Layer):
"""Creates a segmentation head."""

  def __init__(
self,
num_classes: int,
level: Union[int, str],
num_convs: int = 2,
num_filters: int = 256,
use_depthwise_convolution: bool = False,
prediction_kernel_size: int = 1,
upsample_factor: int = 1,
feature_fusion: Optional[str] = None,
decoder_min_level: Optional[int] = None,
decoder_max_level: Optional[int] = None,
low_level: int = 2,
low_level_num_filters: int = 48,
num_decoder_filters: int = 256,
activation: str = 'relu',
logit_activation: Optional[str] = None,
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes a segmentation head.
Args:
      num_classes: An `int` number of mask classification categories. The
        number of classes does not include the background class.
level: An `int` or `str`, level to use to build segmentation head.
num_convs: An `int` number of stacked convolution before the last
prediction layer.
num_filters: An `int` number to specify the number of filters used.
Default is 256.
      use_depthwise_convolution: A `bool` to specify whether to use depthwise
        separable convolutions.
prediction_kernel_size: An `int` number to specify the kernel size of the
prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to
generate finer mask. Default 1 means no upsampling is applied.
      feature_fusion: One of the constants in nn_layers.FeatureFusion, namely
        `deeplabv3plus`, `pyramid_fusion`, `panoptic_fpn_fusion`,
        `deeplabv3plus_sum_to_merge`, or None. If `deeplabv3plus`, features
        from decoder_features[level] will be fused with low-level feature maps
        from the backbone (see the example below). If `pyramid_fusion`,
        multiscale features will be resized and fused at the target level.
decoder_min_level: An `int` of minimum level from decoder to use in
feature fusion. It is only used when feature_fusion is set to
`panoptic_fpn_fusion`.
decoder_max_level: An `int` of maximum level from decoder to use in
feature fusion. It is only used when feature_fusion is set to
`panoptic_fpn_fusion`.
low_level: An `int` of backbone level to be used for feature fusion. It is
used when feature_fusion is set to `deeplabv3plus` or
`deeplabv3plus_sum_to_merge`.
      low_level_num_filters: An `int` of the reduced number of filters for the
        low-level features before fusing them with higher level features. It
        is only used when feature_fusion is set to `deeplabv3plus` or
        `deeplabv3plus_sum_to_merge`.
num_decoder_filters: An `int` of number of filters in the decoder outputs.
It is only used when feature_fusion is set to `panoptic_fpn_fusion`.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
      logit_activation: Activation applied to the final classifier layer
        logits, e.g. 'sigmoid' or 'softmax'. Can be useful when the task does
        not rely solely on cross-entropy loss.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
normalization across different replicas.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
Conv2D. Default is None.
bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
**kwargs: Additional keyword arguments to be passed.
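
    Example (an illustrative sketch; the class count, levels, and filter
    values below are assumptions chosen for demonstration, not values from a
    released config):

      head = SegmentationHead(
          num_classes=19,
          level=3,
          feature_fusion='deeplabv3plus',
          low_level=2,
          low_level_num_filters=48)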
"""
super(SegmentationHead, self).__init__(**kwargs)
self._config_dict = {
'num_classes': num_classes,
'level': level,
'num_convs': num_convs,
'num_filters': num_filters,
'use_depthwise_convolution': use_depthwise_convolution,
'prediction_kernel_size': prediction_kernel_size,
'upsample_factor': upsample_factor,
'feature_fusion': feature_fusion,
'decoder_min_level': decoder_min_level,
'decoder_max_level': decoder_max_level,
'low_level': low_level,
'low_level_num_filters': low_level_num_filters,
'num_decoder_filters': num_decoder_filters,
'activation': activation,
'logit_activation': logit_activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer
}
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the segmentation head."""
use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
conv_op = tf_keras.layers.Conv2D
bn_op = tf_keras.layers.BatchNormalization
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
}
if self._config_dict['feature_fusion'] in {'deeplabv3plus',
'deeplabv3plus_sum_to_merge'}:
# Deeplabv3+ feature fusion layers.
self._dlv3p_conv = conv_op(
kernel_size=1,
padding='same',
use_bias=False,
kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
name='segmentation_head_deeplabv3p_fusion_conv',
filters=self._config_dict['low_level_num_filters'])
self._dlv3p_norm = bn_op(
name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)
elif self._config_dict['feature_fusion'] == 'panoptic_fpn_fusion':
self._panoptic_fpn_fusion = nn_layers.PanopticFPNFusion(
min_level=self._config_dict['decoder_min_level'],
max_level=self._config_dict['decoder_max_level'],
target_level=self._config_dict['level'],
num_filters=self._config_dict['num_filters'],
num_fpn_filters=self._config_dict['num_decoder_filters'],
activation=self._config_dict['activation'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
# Segmentation head layers.
self._convs = []
self._norms = []
for i in range(self._config_dict['num_convs']):
if use_depthwise_convolution:
self._convs.append(
tf_keras.layers.DepthwiseConv2D(
name='segmentation_head_depthwise_conv_{}'.format(i),
kernel_size=3,
padding='same',
use_bias=False,
depthwise_initializer=tf_keras.initializers.RandomNormal(
stddev=0.01),
depthwise_regularizer=self._config_dict['kernel_regularizer'],
depth_multiplier=1))
norm_name = 'segmentation_head_depthwise_norm_{}'.format(i)
self._norms.append(bn_op(name=norm_name, **bn_kwargs))
conv_name = 'segmentation_head_conv_{}'.format(i)
self._convs.append(
conv_op(
name=conv_name,
filters=self._config_dict['num_filters'],
kernel_size=3 if not use_depthwise_convolution else 1,
padding='same',
use_bias=False,
kernel_initializer=tf_keras.initializers.RandomNormal(
stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer']))
norm_name = 'segmentation_head_norm_{}'.format(i)
self._norms.append(bn_op(name=norm_name, **bn_kwargs))
self._classifier = conv_op(
name='segmentation_output',
filters=self._config_dict['num_classes'],
kernel_size=self._config_dict['prediction_kernel_size'],
padding='same',
activation=self._config_dict['logit_activation'],
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
super().build(input_shape)

  def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
"""Forward pass of the segmentation head.
    It supports a tuple of either 2 tensors or 2 dictionaries. The first is
    backbone endpoints, and the second is decoder endpoints. When the inputs
    are tensors, they come from a single level of feature maps. When the
    inputs are dictionaries, they contain multiple levels of feature maps,
    where the key is the level of the feature map.
Args:
inputs: A tuple of 2 feature map tensors of shape
[batch, height_l, width_l, channels] or 2 dictionaries of tensors:
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
The first is backbone endpoints, and the second is decoder endpoints.
Returns:
segmentation prediction mask: A `tf.Tensor` of the segmentation mask
scores predicted from input features.
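
    Example (an illustrative sketch; the feature shapes and the choice of
    `level` below are assumptions chosen for demonstration):

      head = SegmentationHead(num_classes=21, level=3)
      backbone_features = {'3': tf.ones([1, 64, 64, 64])}
      decoder_features = {'3': tf.ones([1, 64, 64, 256])}
      masks = head((backbone_features, decoder_features))  # [1, 64, 64, 21]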
"""
backbone_output = inputs[0]
decoder_output = inputs[1]
if self._config_dict['feature_fusion'] in {'deeplabv3plus',
'deeplabv3plus_sum_to_merge'}:
# deeplabv3+ feature fusion
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
y = backbone_output[str(self._config_dict['low_level'])] if isinstance(
backbone_output, dict) else backbone_output
y = self._dlv3p_norm(self._dlv3p_conv(y))
y = self._activation(y)
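      # Resizes the decoder feature map to the spatial size of the low-level
      # backbone feature before fusing them.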
x = tf.image.resize(
x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.cast(x, dtype=y.dtype)
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
x = tf.concat([x, y], axis=self._bn_axis)
else:
x = tf_keras.layers.Add()([x, y])
elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
if not isinstance(decoder_output, dict):
        raise ValueError(
            'decoder_output must be a dictionary when using pyramid_fusion.')
x = nn_layers.pyramid_feature_fusion(decoder_output,
self._config_dict['level'])
elif self._config_dict['feature_fusion'] == 'panoptic_fpn_fusion':
x = self._panoptic_fpn_fusion(decoder_output)
else:
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
for conv, norm in zip(self._convs, self._norms):
x = conv(x)
x = norm(x)
x = self._activation(x)
if self._config_dict['upsample_factor'] > 1:
x = spatial_transform_ops.nearest_upsampling(
x, scale=self._config_dict['upsample_factor'])
return self._classifier(x)

  def get_config(self):
base_config = super().get_config()
return dict(list(base_config.items()) + list(self._config_dict.items()))

  @classmethod
def from_config(cls, config):
return cls(**config)