# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of instance prediction heads."""

from typing import List, Union, Optional

# Import libraries

import tensorflow as tf, tf_keras

from official.modeling import tf_utils


# Keyword names under which a conv layer receives a weight initializer.
# `Conv2D` uses the first one; `SeparableConv2D` uses the other two.
_CONV_INITIALIZER_KEYS = (
    'kernel_initializer',
    'depthwise_initializer',
    'pointwise_initializer',
)


def _make_conv_kwargs(
    filters: int,
    kernel_size: int,
    padding: str,
    use_separable_conv: bool,
    kernel_regularizer: Optional[tf_keras.regularizers.Regularizer],
    bias_regularizer: Optional[tf_keras.regularizers.Regularizer],
) -> dict:
  """Builds constructor kwargs for `Conv2D` or `SeparableConv2D`.

  Every call creates new `VarianceScaling` initializer objects, so two dicts
  returned by this function never share an initializer instance.

  Args:
    filters: An `int` number of output filters.
    kernel_size: An `int` spatial kernel size.
    padding: A `str` padding mode, e.g. 'same' or 'valid'.
    use_separable_conv: A `bool`. If True, the returned keys are the
      depthwise/pointwise ones expected by `SeparableConv2D`; otherwise they
      are the kernel ones expected by `Conv2D`.
    kernel_regularizer: Regularizer applied to the kernel (or to both the
      depthwise and pointwise kernels). May be None.
    bias_regularizer: Regularizer applied to the bias. May be None.

  Returns:
    A `dict` of keyword arguments, without `name`.
  """

  def new_initializer():
    return tf_keras.initializers.VarianceScaling(
        scale=2, mode='fan_out', distribution='untruncated_normal')

  conv_kwargs = {
      'filters': filters,
      'kernel_size': kernel_size,
      'padding': padding,
      'bias_initializer': tf.zeros_initializer(),
      'bias_regularizer': bias_regularizer,
  }
  if use_separable_conv:
    conv_kwargs.update({
        'depthwise_initializer': new_initializer(),
        'pointwise_initializer': new_initializer(),
        'depthwise_regularizer': kernel_regularizer,
        'pointwise_regularizer': kernel_regularizer,
    })
  else:
    conv_kwargs.update({
        'kernel_initializer': new_initializer(),
        'kernel_regularizer': kernel_regularizer,
    })
  return conv_kwargs


def _clone_conv_initializers(conv_kwargs: dict) -> None:
  """Replaces, in place, every weight initializer in `conv_kwargs` by a clone.

  Called once per layer when the same kwargs dict is reused for several conv
  layers, so that no two layers are handed the same initializer object.

  Args:
    conv_kwargs: A `dict` of conv layer kwargs; only the keys listed in
      `_CONV_INITIALIZER_KEYS` that are present are touched.
  """
  for key in _CONV_INITIALIZER_KEYS:
    if key in conv_kwargs:
      conv_kwargs[key] = tf_utils.clone_initializer(conv_kwargs[key])


def _make_bn_kwargs(config_dict: dict, bn_axis: int) -> dict:
  """Builds `BatchNormalization` kwargs from a head's config dict.

  Args:
    config_dict: A head config holding 'norm_momentum', 'norm_epsilon' and
      'use_sync_bn'.
    bn_axis: An `int` axis to normalize over.

  Returns:
    A `dict` of keyword arguments, without `name`.
  """
  return {
      'axis': bn_axis,
      'momentum': config_dict['norm_momentum'],
      'epsilon': config_dict['norm_epsilon'],
      'synchronized': config_dict['use_sync_bn'],
  }


@tf_keras.utils.register_keras_serializable(package='Vision')
class DetectionHead(tf_keras.layers.Layer):
  """Creates a detection head.

  The head applies optional conv + batch-norm + activation stages to each ROI,
  flattens the result, applies FC + batch-norm + activation stages, and then
  emits class scores and box regression outputs from two `Dense` layers.
  """

  def __init__(
      self,
      num_classes: int,
      num_convs: int = 0,
      num_filters: int = 256,
      use_separable_conv: bool = False,
      num_fcs: int = 2,
      fc_dims: int = 1024,
      class_agnostic_bbox_pred: bool = False,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a detection head.

    Args:
      num_classes: An `int` for the number of classes.
      num_convs: An `int` number that represents the number of the intermediate
        convolution layers before the FC layers.
      num_filters: An `int` number that represents the number of filters of the
        intermediate convolution layers.
      use_separable_conv: A `bool` that indicates whether the separable
        convolution layers is used.
      num_fcs: An `int` number that represents the number of FC layers before
        the predictions.
      fc_dims: An `int` number that represents the number of dimension of the FC
        layers.
      class_agnostic_bbox_pred: A `bool`. If True, a single 4-value box is
        predicted per ROI; otherwise 4 values are predicted for every class.
      activation: A `str` that indicates which activation is used, e.g. 'relu',
        'swish', etc.
      use_sync_bn: A `bool` that indicates whether to use synchronized batch
        normalization across different replicas.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object applied
        to the conv and dense kernels. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object applied to
        the conv and dense biases. Default is None.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._config_dict = {
        'num_classes': num_classes,
        'num_convs': num_convs,
        'num_filters': num_filters,
        'use_separable_conv': use_separable_conv,
        'num_fcs': num_fcs,
        'fc_dims': fc_dims,
        'class_agnostic_bbox_pred': class_agnostic_bbox_pred,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }

    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation = tf_utils.get_activation(activation)

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
    """Creates the sub-layers of the head.

    Args:
      input_shape: Shape(s) of the input; only forwarded to the parent `build`.
    """
    cfg = self._config_dict
    kernel_regularizer = cfg['kernel_regularizer']
    bias_regularizer = cfg['bias_regularizer']

    if cfg['use_separable_conv']:
      conv_op = tf_keras.layers.SeparableConv2D
    else:
      conv_op = tf_keras.layers.Conv2D
    conv_kwargs = _make_conv_kwargs(
        filters=cfg['num_filters'],
        kernel_size=3,
        padding='same',
        use_separable_conv=cfg['use_separable_conv'],
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer)

    bn_op = tf_keras.layers.BatchNormalization
    bn_kwargs = _make_bn_kwargs(cfg, self._bn_axis)

    self._convs = []
    self._conv_norms = []
    for i in range(cfg['num_convs']):
      # Clone every weight initializer, including the depthwise/pointwise ones
      # used by SeparableConv2D, so that conv layers never share an
      # initializer instance (same policy as MaskHead).
      _clone_conv_initializers(conv_kwargs)
      conv_name = 'detection-conv_{}'.format(i)
      self._convs.append(conv_op(name=conv_name, **conv_kwargs))
      bn_name = 'detection-conv-bn_{}'.format(i)
      self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))

    self._fcs = []
    self._fc_norms = []
    for i in range(cfg['num_fcs']):
      fc_name = 'detection-fc_{}'.format(i)
      self._fcs.append(
          tf_keras.layers.Dense(
              units=cfg['fc_dims'],
              kernel_initializer=tf_keras.initializers.VarianceScaling(
                  scale=1 / 3.0, mode='fan_out', distribution='uniform'),
              kernel_regularizer=kernel_regularizer,
              bias_regularizer=bias_regularizer,
              name=fc_name))
      bn_name = 'detection-fc-bn_{}'.format(i)
      self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))

    self._classifier = tf_keras.layers.Dense(
        units=cfg['num_classes'],
        kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        name='detection-scores')

    # One box (4 values) per ROI when class-agnostic, else one box per class.
    if cfg['class_agnostic_bbox_pred']:
      num_box_outputs = 4
    else:
      num_box_outputs = cfg['num_classes'] * 4
    self._box_regressor = tf_keras.layers.Dense(
        units=num_box_outputs,
        kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.001),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        name='detection-boxes')

    super().build(input_shape)

  def call(self, inputs: tf.Tensor, training: Optional[bool] = None):
    """Forward pass of box and class branches for the Mask-RCNN model.

    Args:
      inputs: A `tf.Tensor` of the shape [batch_size, num_instances, roi_height,
        roi_width, roi_channels], representing the ROI features. All
        dimensions except the first are read from the static shape.
      training: a `bool` indicating whether it is in `training` mode.
        NOTE(review): it is not forwarded explicitly to the sub-layers here;
        presumably they pick it up from the Keras call context - confirm.

    Returns:
      class_outputs: A `tf.Tensor` of the shape
        [batch_size, num_instances, num_classes], representing the class
        predictions.
      box_outputs: A `tf.Tensor` of the shape
        [batch_size, num_instances, num_classes * 4], or
        [batch_size, num_instances, 4] when `class_agnostic_bbox_pred` is set,
        representing the box predictions.
    """
    roi_features = inputs
    _, num_rois, height, width, filters = roi_features.get_shape().as_list()

    # Fold the ROI axis into the batch axis so the 2-D convs get 4-D input.
    x = tf.reshape(roi_features, [-1, height, width, filters])
    for conv, bn in zip(self._convs, self._conv_norms):
      x = conv(x)
      x = bn(x)
      x = self._activation(x)

    # Re-read the channel count after the convs, then flatten each ROI.
    _, _, _, filters = x.get_shape().as_list()
    x = tf.reshape(x, [-1, num_rois, height * width * filters])

    for fc, bn in zip(self._fcs, self._fc_norms):
      x = fc(x)
      x = bn(x)
      x = self._activation(x)

    classes = self._classifier(x)
    boxes = self._box_regressor(x)
    return classes, boxes

  def get_config(self):
    """Returns the dict of constructor arguments stored at init (not a copy)."""
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    """Creates a head by passing `config` as keyword arguments."""
    return cls(**config)


@tf_keras.utils.register_keras_serializable(package='Vision')
class MaskHead(tf_keras.layers.Layer):
  """Creates a mask head.

  The head applies conv + batch-norm + activation stages to each ROI, upsamples
  with a transposed convolution, and emits mask logits from a 1x1 convolution.
  """

  def __init__(
      self,
      num_classes: int,
      upsample_factor: int = 2,
      num_convs: int = 4,
      num_filters: int = 256,
      use_separable_conv: bool = False,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      class_agnostic: bool = False,
      **kwargs):
    """Initializes a mask head.

    Args:
      num_classes: An `int` of the number of classes.
      upsample_factor: An `int` that indicates the upsample factor to generate
        the final predicted masks. It should be >= 1.
      num_convs: An `int` number that represents the number of the intermediate
        convolution layers before the mask prediction layers.
      num_filters: An `int` number that represents the number of filters of the
        intermediate convolution layers.
      use_separable_conv: A `bool` that indicates whether the separable
        convolution layers is used.
      activation: A `str` that indicates which activation is used, e.g. 'relu',
        'swish', etc.
      use_sync_bn: A `bool` that indicates whether to use synchronized batch
        normalization across different replicas.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object applied
        to the conv and transposed-conv kernels. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object applied to
        the conv and transposed-conv biases. Default is None.
      class_agnostic: A `bool`. If set, we use a single channel mask head that
        is shared between all classes.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._config_dict = {
        'num_classes': num_classes,
        'upsample_factor': upsample_factor,
        'num_convs': num_convs,
        'num_filters': num_filters,
        'use_separable_conv': use_separable_conv,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'class_agnostic': class_agnostic
    }

    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation = tf_utils.get_activation(activation)

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
    """Creates the sub-layers of the head.

    Args:
      input_shape: Shape(s) of the input; only forwarded to the parent `build`.
    """
    cfg = self._config_dict
    kernel_regularizer = cfg['kernel_regularizer']
    bias_regularizer = cfg['bias_regularizer']

    if cfg['use_separable_conv']:
      conv_op = tf_keras.layers.SeparableConv2D
    else:
      conv_op = tf_keras.layers.Conv2D
    conv_kwargs = _make_conv_kwargs(
        filters=cfg['num_filters'],
        kernel_size=3,
        padding='same',
        use_separable_conv=cfg['use_separable_conv'],
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer)

    bn_op = tf_keras.layers.BatchNormalization
    bn_kwargs = _make_bn_kwargs(cfg, self._bn_axis)

    self._convs = []
    self._conv_norms = []
    for i in range(cfg['num_convs']):
      # The kwargs dict is reused for every layer, so give each layer its own
      # initializer objects.
      _clone_conv_initializers(conv_kwargs)
      conv_name = 'mask-conv_{}'.format(i)
      self._convs.append(conv_op(name=conv_name, **conv_kwargs))
      bn_name = 'mask-conv-bn_{}'.format(i)
      self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))

    # Kernel size equal to stride with 'valid' padding: each input pixel maps
    # to its own upsample_factor x upsample_factor output patch.
    self._deconv = tf_keras.layers.Conv2DTranspose(
        filters=cfg['num_filters'],
        kernel_size=cfg['upsample_factor'],
        strides=cfg['upsample_factor'],
        padding='valid',
        kernel_initializer=tf_keras.initializers.VarianceScaling(
            scale=2, mode='fan_out', distribution='untruncated_normal'),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        name='mask-upsampling')
    self._deconv_bn = bn_op(name='mask-deconv-bn', **bn_kwargs)

    # One logit channel in total when class-agnostic, else one per class.
    if cfg['class_agnostic']:
      num_filters = 1
    else:
      num_filters = cfg['num_classes']

    regressor_kwargs = _make_conv_kwargs(
        filters=num_filters,
        kernel_size=1,
        padding='valid',
        use_separable_conv=cfg['use_separable_conv'],
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer)
    self._mask_regressor = conv_op(name='mask-logits', **regressor_kwargs)

    super().build(input_shape)

  def call(self, inputs: List[tf.Tensor], training: Optional[bool] = None):
    """Forward pass of mask branch for the Mask-RCNN model.

    Args:
      inputs: A `list` of two tensors where
        inputs[0]: A `tf.Tensor` of shape [batch_size, num_instances,
          roi_height, roi_width, roi_channels], representing the ROI features.
          All dimensions except the first are read from the static shape.
        inputs[1]: A `tf.Tensor` of shape [batch_size, num_instances],
          representing the classes of the ROIs. It is cast to int32 and only
          used when the head is not class-agnostic.
      training: A `bool` indicating whether it is in `training` mode.
        NOTE(review): it is not forwarded explicitly to the sub-layers here;
        presumably they pick it up from the Keras call context - confirm.

    Returns:
      mask_outputs: A `tf.Tensor` of shape
        [batch_size, num_instances, roi_height * upsample_factor,
         roi_width * upsample_factor],
        representing the mask predictions. When the head is not class-agnostic,
        this is the logit channel selected by the class of each ROI.
    """
    roi_features, roi_classes = inputs
    _, num_rois, height, width, filters = roi_features.get_shape().as_list()

    # Fold the ROI axis into the batch axis so the 2-D convs get 4-D input.
    x = tf.reshape(roi_features, [-1, height, width, filters])
    for conv, bn in zip(self._convs, self._conv_norms):
      x = conv(x)
      x = bn(x)
      x = self._activation(x)

    x = self._deconv(x)
    x = self._deconv_bn(x)
    x = self._activation(x)

    logits = self._mask_regressor(x)

    upsample_factor = self._config_dict['upsample_factor']
    mask_height = height * upsample_factor
    mask_width = width * upsample_factor

    if self._config_dict['class_agnostic']:
      # Single logit channel: restoring the ROI axis absorbs it.
      return tf.reshape(logits, [-1, num_rois, mask_height, mask_width])

    logits = tf.reshape(
        logits,
        [-1, num_rois, mask_height, mask_width,
         self._config_dict['num_classes']])
    # With batch_dims=2, each (batch, roi) pair picks its own class channel.
    return tf.gather(
        logits, tf.cast(roi_classes, dtype=tf.int32), axis=-1, batch_dims=2
    )

  def get_config(self):
    """Returns the dict of constructor arguments stored at init (not a copy)."""
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    """Creates a head by passing `config` as keyword arguments."""
    return cls(**config)