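# NOTE: hosting-page residue captured with this file (not part of the module):
# "deanna-emery's picture", "updates", commit 93528c6, "raw / history blame",
# file size 15.8 kB.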
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of Residual Networks with Deeplab modifications."""
from typing import Callable, Optional, Tuple, List
import numpy as np
import tensorflow as tf, tf_keras
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_layers
layers = tf_keras.layers
# Block layouts for the supported ResNet depths.
# `RESNET_SPECS[depth]` is a list with one entry per stage, each of the form:
#   (block_fn, num_filters, block_repeats)
# Every supported depth uses bottleneck blocks with the same per-stage filter
# counts; only the number of block repeats per stage differs.
_STAGE_FILTERS = (64, 128, 256, 512)
_STAGE_REPEATS = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
    200: (3, 24, 36, 3),
}
RESNET_SPECS = {
    depth: [
        ('bottleneck', num_filters, num_repeats)
        for num_filters, num_repeats in zip(_STAGE_FILTERS, repeats)
    ]
    for depth, repeats in _STAGE_REPEATS.items()
}
@tf_keras.utils.register_keras_serializable(package='Vision')
class DilatedResNet(tf_keras.Model):
  """Creates a ResNet model with Deeplabv3 modifications.

  This backbone is suitable for semantic segmentation. This implements
    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
    Rethinking Atrous Convolution for Semantic Image Segmentation.
    (https://arxiv.org/pdf/1706.05587)

  The model is assembled with the Keras functional API inside `__init__`: a
  stem, a run of regular (strided) block groups, and then block groups that
  keep stride 1 and use increasing dilation rates instead. The outputs are a
  dict of endpoint tensors keyed by level (as a `str`).
  """

  def __init__(
      self,
      model_id: int,
      output_stride: int,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, 3]),
      stem_type: str = 'v0',
      resnetd_shortcut: bool = False,
      replace_stem_max_pool: bool = False,
      se_ratio: Optional[float] = None,
      init_stochastic_depth_rate: float = 0.0,
      multigrid: Optional[Tuple[int, ...]] = None,
      last_stage_repeats: int = 1,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a ResNet model with DeepLab modification.

    Args:
      model_id: An `int` specifies depth of ResNet backbone model. Must be a
        key of `RESNET_SPECS`.
      output_stride: An `int` of output stride, ratio of input to output
        resolution.
      input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
      stem_type: A `str` of stem type. Can be `v0` or `v1`. `v1` replaces 7x7
        conv by 3 3x3 convs.
      resnetd_shortcut: A `bool` of whether to use ResNet-D shortcut in
        downsampling blocks.
      replace_stem_max_pool: A `bool` of whether to replace the max pool in stem
        with a stride-2 conv.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      multigrid: A tuple of the same length as the number of blocks in the last
        resnet stage.
      last_stage_repeats: An `int` that specifies how many times last stage is
        repeated.
      activation: A `str` name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If `stem_type` is not `v0` or `v1`, or if a stage in
        `RESNET_SPECS` names an unsupported block type.
    """
    self._model_id = model_id
    self._output_stride = output_stride
    self._input_specs = input_specs
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._stem_type = stem_type
    self._resnetd_shortcut = resnetd_shortcut
    self._replace_stem_max_pool = replace_stem_max_pool
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate

    if tf_keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build ResNet.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])

    # Stem: either a single 7x7 conv (`v0`) or three 3x3 convs (`v1`).
    if stem_type == 'v0':
      x = self._conv_norm_act(
          inputs, filters=64, kernel_size=7, strides=2, bn_axis=bn_axis)
    elif stem_type == 'v1':
      x = inputs
      for stem_filters, stem_strides in ((64, 2), (64, 1), (128, 1)):
        x = self._conv_norm_act(
            x,
            filters=stem_filters,
            kernel_size=3,
            strides=stem_strides,
            bn_axis=bn_axis)
    else:
      raise ValueError('Stem type {} not supported.'.format(stem_type))

    if replace_stem_max_pool:
      # This branch (unlike the stem convs above) requests the activation as a
      # Keras layer; that distinction is kept as-is.
      x = self._conv_norm_act(
          x,
          filters=64,
          kernel_size=3,
          strides=2,
          bn_axis=bn_axis,
          use_keras_layer=True)
    else:
      x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    # Index of the last stage that still uses ordinary strided convolutions;
    # later stages keep stride 1 and dilate instead. `np.log2` is used because
    # the `np.math` alias was removed in NumPy 2.0.
    normal_resnet_stage = int(np.log2(self._output_stride)) - 2
    num_stochastic_depth_levels = 4 + last_stage_repeats

    endpoints = {}
    for i in range(normal_resnet_stage + 1):
      spec = RESNET_SPECS[model_id][i]
      if spec[0] == 'bottleneck':
        block_fn = nn_blocks.BottleneckBlock
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
      x = self._block_group(
          inputs=x,
          filters=spec[1],
          strides=(1 if i == 0 else 2),
          dilation_rate=1,
          block_fn=block_fn,
          block_repeats=spec[2],
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 2,
              num_stochastic_depth_levels),
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i + 2)] = x

    # Dilated stages: stride stays 1 and the dilation rate doubles per stage.
    # Indices >= 3 reuse the last spec entry (repeated `last_stage_repeats`
    # times) and are the only ones that receive `multigrid`.
    dilation_rate = 2
    for i in range(normal_resnet_stage + 1, 3 + last_stage_repeats):
      spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[model_id][-1]
      if spec[0] == 'bottleneck':
        block_fn = nn_blocks.BottleneckBlock
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
      x = self._block_group(
          inputs=x,
          filters=spec[1],
          strides=1,
          dilation_rate=dilation_rate,
          block_fn=block_fn,
          block_repeats=spec[2],
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 2,
              num_stochastic_depth_levels),
          multigrid=multigrid if i >= 3 else None,
          name='block_group_l{}'.format(i + 2))
      dilation_rate *= 2
    # The final feature map replaces the endpoint of the last strided stage.
    endpoints[str(normal_resnet_stage + 2)] = x

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(DilatedResNet, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

    # Recorded only so `get_config()` can reproduce the architecture.
    self._multigrid = multigrid
    self._last_stage_repeats = last_stage_repeats

  def _conv_norm_act(self,
                     inputs: tf.Tensor,
                     filters: int,
                     kernel_size: int,
                     strides: int,
                     bn_axis: int,
                     **activation_kwargs) -> tf.Tensor:
    """Applies the stem's Conv2D -> normalization -> activation sequence.

    Args:
      inputs: The input `tf.Tensor`.
      filters: An `int` number of convolution filters.
      kernel_size: An `int` convolution kernel size.
      strides: An `int` convolution stride.
      bn_axis: An `int` axis passed to the normalization layer.
      **activation_kwargs: Extra keyword arguments forwarded unchanged to
        `tf_utils.get_activation` (e.g. `use_keras_layer=True`).

    Returns:
      The output `tf.Tensor` after the activation.
    """
    x = layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        synchronized=self._use_sync_bn)(
            x)
    return tf_utils.get_activation(self._activation, **activation_kwargs)(x)

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   strides: int,
                   dilation_rate: int,
                   block_fn: Callable[..., tf_keras.layers.Layer],
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   multigrid: Optional[List[int]] = None,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet model.

    The first block uses `strides` and a projection shortcut; the remaining
    `block_repeats - 1` blocks use stride 1 and no projection.

    Args:
      inputs: The input `tf.Tensor` feature map.
      filters: An `int` number of filters for the first convolution of the
        layer.
      strides: An `int` of stride to use for the first convolution of the layer.
        If greater than 1, this layer will downsample the input.
      dilation_rate: An `int` base rate for dilated convolutions.
      block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
      block_repeats: An `int` of number of blocks contained in the layer.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      multigrid: A list of `int` or None. If specified, dilation rates for each
        block is scaled up by its corresponding factor in the multigrid.
      name: A `str` name for the block.

    Returns:
      The output `tf.Tensor` of the block layer.

    Raises:
      ValueError: If `multigrid` is given and its length differs from
        `block_repeats`.
    """
    if multigrid is not None and len(multigrid) != block_repeats:
      raise ValueError('multigrid has to match number of block_repeats')

    if multigrid is None:
      multigrid = [1] * block_repeats

    # Arguments shared by every block in the group.
    shared_block_kwargs = dict(
        filters=filters,
        stochastic_depth_drop_rate=stochastic_depth_drop_rate,
        se_ratio=self._se_ratio,
        resnetd_shortcut=self._resnetd_shortcut,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon)

    # TODO(arashwan): move striding to the end of the block.
    x = block_fn(
        strides=strides,
        dilation_rate=dilation_rate * multigrid[0],
        use_projection=True,
        **shared_block_kwargs)(
            inputs)
    for i in range(1, block_repeats):
      x = block_fn(
          strides=1,
          dilation_rate=dilation_rate * multigrid[i],
          use_projection=False,
          **shared_block_kwargs)(
              x)

    return tf.identity(x, name=name)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model.

    `input_specs` is not included, so `from_config` rebuilds with the default
    input spec. `multigrid` is emitted as a plain list (or None).
    """
    multigrid = self._multigrid
    config_dict = {
        'model_id': self._model_id,
        'output_stride': self._output_stride,
        'stem_type': self._stem_type,
        'resnetd_shortcut': self._resnetd_shortcut,
        'replace_stem_max_pool': self._replace_stem_max_pool,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        # Without these two, a config round trip would silently rebuild the
        # model with the default multigrid / last-stage repeat count.
        'multigrid': None if multigrid is None else list(multigrid),
        'last_stage_repeats': self._last_stage_repeats,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model by passing `config` as constructor keyword args."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
@factory.register_backbone_builder('dilated_resnet')
def build_dilated_resnet(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf_keras.regularizers.Regularizer = None) -> tf_keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds a `DilatedResNet` backbone from config objects.

  Args:
    input_specs: Input spec forwarded to `DilatedResNet`.
    backbone_config: Config whose `type` must be `'dilated_resnet'`; the
      architecture settings are read from `backbone_config.get()`.
    norm_activation_config: Config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Optional regularizer used as the kernel regularizer.

  Returns:
    The constructed `DilatedResNet` model.
  """
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'dilated_resnet', (
      f'Inconsistent backbone type {backbone_type}')

  # Architecture options come from the backbone-specific config ...
  architecture_kwargs = dict(
      model_id=backbone_cfg.model_id,
      output_stride=backbone_cfg.output_stride,
      stem_type=backbone_cfg.stem_type,
      resnetd_shortcut=backbone_cfg.resnetd_shortcut,
      replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
      se_ratio=backbone_cfg.se_ratio,
      init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
      multigrid=backbone_cfg.multigrid,
      last_stage_repeats=backbone_cfg.last_stage_repeats)
  # ... while normalization/activation options come from the shared config.
  norm_activation_kwargs = dict(
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon)

  return DilatedResNet(
      input_specs=input_specs,
      kernel_regularizer=l2_regularizer,
      **architecture_kwargs,
      **norm_activation_kwargs)