deanna-emery's picture
updates
93528c6
raw
history blame
18.7 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of 3D Residual Networks."""
from typing import Callable, List, Tuple, Optional
# Import libraries
import tensorflow as tf, tf_keras
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_blocks_3d
from official.vision.modeling.layers import nn_layers
# Short module-level alias for the Keras layers namespace used throughout this
# file (e.g. `layers.InputSpec`, `layers.Conv3D`, `layers.MaxPool3D`).
layers = tf_keras.layers
# Block layout for every supported model depth: one
# `(block_type, filters, block_repeats)` tuple per block group. All depths use
# four `bottleneck3d` groups with 64/128/256/512 filters; they differ only in
# how many blocks each group repeats.
RESNET_SPECS = {
    depth: [
        ('bottleneck3d', filters, repeats)
        for filters, repeats in zip((64, 128, 256, 512), group_repeats)
    ]
    for depth, group_repeats in (
        (50, (3, 4, 6, 3)),
        (101, (3, 4, 23, 3)),
        (152, (3, 8, 36, 3)),
        (200, (3, 24, 36, 3)),
        (270, (4, 29, 53, 4)),
        (300, (4, 36, 54, 4)),
        (350, (4, 36, 72, 4)),
    )
}
@tf_keras.utils.register_keras_serializable(package='Vision')
class ResNet3D(tf_keras.Model):
  """Creates a 3D ResNet family model.

  The network is a stem (see `_build_stem`), a max-pooling layer, and one
  group of residual blocks per entry of `RESNET_SPECS[model_id]`. The model
  outputs a dict mapping block-group level strings ('2', '3', ...) to that
  group's output feature map; their static shapes are exposed through
  `output_specs`.
  """

  def __init__(
      self,
      model_id: int,
      temporal_strides: List[int],
      temporal_kernel_sizes: List[Tuple[int, ...]],
      use_self_gating: Optional[List[bool]] = None,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, None, 3]),
      stem_type: str = 'v0',
      stem_conv_temporal_kernel_size: int = 5,
      stem_conv_temporal_stride: int = 2,
      stem_pool_temporal_stride: int = 2,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      se_ratio: Optional[float] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a 3D ResNet model.

    Args:
      model_id: An `int` depth of the ResNet backbone; must be a key of
        `RESNET_SPECS`.
      temporal_strides: A list with one `int` temporal stride per block group.
        The stride is applied to the first block of the group only.
      temporal_kernel_sizes: A list with one entry per block group. Each entry
        is a sequence holding one temporal kernel size per block in that
        group, so its length must equal the group's block count.
      use_self_gating: An optional list with one `bool` per block group. When
        an entry is True, the self-gating module is applied to the last block
        of that group. If None (or empty), self-gating is not applied.
      input_specs: A `tf_keras.layers.InputSpec` of the input tensor. Its
        first (batch) dimension is dropped when creating the model input.
      stem_type: A `str` stem type. `v0` (default) uses a single conv layer
        with a 7x7 spatial kernel. `v1` uses a ResNet-D type stem
        (https://arxiv.org/abs/1812.01187) of three convs with 3x3 spatial
        kernels.
      stem_conv_temporal_kernel_size: An `int` temporal kernel size of the
        first stem conv layer.
      stem_conv_temporal_stride: An `int` temporal stride of the first stem
        conv layer.
      stem_pool_temporal_stride: An `int` temporal stride of the max-pooling
        layer that follows the stem.
      init_stochastic_depth_rate: A `float` initial stochastic depth rate.
      activation: A `str` name of the activation function.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object passed
        to the Conv3D stem layers and the residual blocks. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object passed to
        the Conv3D stem layers and the residual blocks. Default to None.
      **kwargs: Additional keyword arguments passed to `tf_keras.Model`.

    Raises:
      KeyError: If `model_id` is not a key of `RESNET_SPECS`.
      ValueError: If the per-group settings do not match the number of block
        groups, a group's temporal kernel sizes do not match its block count,
        or `stem_type` is not supported.
    """
    self._model_id = model_id
    self._temporal_strides = temporal_strides
    self._temporal_kernel_sizes = temporal_kernel_sizes
    self._input_specs = input_specs
    self._stem_type = stem_type
    self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
    self._stem_conv_temporal_stride = stem_conv_temporal_stride
    self._stem_pool_temporal_stride = stem_pool_temporal_stride
    self._use_self_gating = use_self_gating
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # Normalize over the channel axis, whose position depends on the global
    # Keras data format.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

    # Build the ResNet3D backbone as a functional model over a symbolic input.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])
    endpoints = self._build_model(inputs)
    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
    super().__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _build_model(self, inputs):
    """Builds the stem, the stem max-pooling and all residual block groups.

    Args:
      inputs: The symbolic `tf_keras.Input` tensor created in `__init__`.

    Returns:
      endpoints: A dict mapping each block-group level, as a string starting
        at '2', to that group's output tensor.

    Raises:
      KeyError: If `self._model_id` is not a key of `RESNET_SPECS`.
      ValueError: If the per-group settings do not match the number of block
        groups or a block type is not supported.
    """
    # Build stem.
    x = self._build_stem(inputs, stem_type=self._stem_type)
    # Use a temporal pooling window of 1 when there is no temporal
    # downsampling in the pool, otherwise 3.
    temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
    x = layers.MaxPool3D(
        pool_size=[temporal_kernel_size, 3, 3],
        strides=[self._stem_pool_temporal_stride, 2, 2],
        padding='same')(x)

    # Build intermediate blocks and endpoints.
    resnet_specs = RESNET_SPECS[self._model_id]
    num_groups = len(resnet_specs)
    if len(self._temporal_strides) != num_groups or len(
        self._temporal_kernel_sizes) != num_groups:
      raise ValueError(
          'Number of blocks in temporal specs should equal to resnet_specs.')
    # A shorter list would otherwise fail with an opaque IndexError below.
    # Extra trailing entries are ignored, as before.
    if self._use_self_gating and len(self._use_self_gating) < num_groups:
      raise ValueError(
          'Number of elements in `use_self_gating` must be at least the '
          'number of block groups in resnet_specs.')

    endpoints = {}
    for i, resnet_spec in enumerate(resnet_specs):
      if resnet_spec[0] == 'bottleneck3d':
        block_fn = nn_blocks_3d.BottleneckBlock3D
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(
            resnet_spec[0]))

      use_self_gating = (
          self._use_self_gating[i] if self._use_self_gating else False)
      x = self._block_group(
          inputs=x,
          filters=resnet_spec[1],
          temporal_kernel_sizes=self._temporal_kernel_sizes[i],
          temporal_strides=self._temporal_strides[i],
          # Spatial stride 1 for the first group, 2 for every later group.
          spatial_strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=resnet_spec[2],
          # Per-group rate derived from the initial rate and this group's
          # level (i + 2) out of 5; see nn_layers.get_stochastic_depth_rate.
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 2, 5),
          use_self_gating=use_self_gating,
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i + 2)] = x

    return endpoints

  def _conv_norm_activation(self, x, filters, kernel_size, strides):
    """Applies a bias-free 'same'-padded Conv3D, batch norm and activation.

    Args:
      x: The input tensor.
      filters: An `int` number of Conv3D filters.
      kernel_size: A `[temporal, height, width]` Conv3D kernel size.
      strides: A `[temporal, height, width]` Conv3D stride.

    Returns:
      The activated output tensor.
    """
    x = layers.Conv3D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            x)
    x = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        synchronized=self._use_sync_bn)(x)
    return tf_utils.get_activation(self._activation)(x)

  def _build_stem(self, inputs, stem_type):
    """Builds the stem layers.

    Args:
      inputs: The input tensor of the model.
      stem_type: `'v0'` for a single conv with a 7x7 spatial kernel and 64
        filters, or `'v1'` for three convs with 3x3 spatial kernels (32, 32
        and 64 filters). Only the first conv of either stem is temporally and
        spatially strided.

    Returns:
      The stem output tensor.

    Raises:
      ValueError: If `stem_type` is not supported.
    """
    if stem_type == 'v0':
      x = self._conv_norm_activation(
          inputs,
          filters=64,
          kernel_size=[self._stem_conv_temporal_kernel_size, 7, 7],
          strides=[self._stem_conv_temporal_stride, 2, 2])
    elif stem_type == 'v1':
      x = self._conv_norm_activation(
          inputs,
          filters=32,
          kernel_size=[self._stem_conv_temporal_kernel_size, 3, 3],
          strides=[self._stem_conv_temporal_stride, 2, 2])
      x = self._conv_norm_activation(
          x, filters=32, kernel_size=[1, 3, 3], strides=[1, 1, 1])
      x = self._conv_norm_activation(
          x, filters=64, kernel_size=[1, 3, 3], strides=[1, 1, 1])
    else:
      raise ValueError(f'Stem type {stem_type} not supported.')
    return x

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   temporal_kernel_sizes: Tuple[int, ...],
                   temporal_strides: int,
                   spatial_strides: int,
                   block_fn: Callable[
                       ...,
                       tf_keras.layers.Layer] = nn_blocks_3d.BottleneckBlock3D,
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   use_self_gating: bool = False,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet3D model.

    Args:
      inputs: A 5D `tf.Tensor`, e.g. `[batch, time, height, width, channels]`
        for the `channels_last` data format.
      filters: An `int` passed as `filters` to every block in the group.
      temporal_kernel_sizes: A sequence holding one temporal kernel size per
        block in the group.
      temporal_strides: An `int` temporal stride of the first block in this
        group. Later blocks use a temporal stride of 1.
      spatial_strides: An `int` spatial stride of the first block in this
        group. Later blocks use a spatial stride of 1.
      block_fn: A callable that builds one block layer, e.g.
        `nn_blocks_3d.BottleneckBlock3D`.
      block_repeats: An `int` number of blocks in the group; must be >= 1.
      stochastic_depth_drop_rate: A `float` drop rate shared by all blocks of
        the group.
      use_self_gating: A `bool` that specifies whether to apply the
        self-gating module to the last block of the group.
      name: A `str` name for the group's output tensor.

    Returns:
      The output `tf.Tensor` of the last block, wrapped in `tf.identity` with
      the given `name`.

    Raises:
      ValueError: If `block_repeats` < 1 or `len(temporal_kernel_sizes)` !=
        `block_repeats`.
    """
    if block_repeats < 1:
      raise ValueError(
          f'`block_repeats` must be at least 1, got {block_repeats}.')
    if len(temporal_kernel_sizes) != block_repeats:
      raise ValueError(
          'Number of elements in `temporal_kernel_sizes` must equal to '
          '`block_repeats`.')

    last_index = block_repeats - 1
    x = inputs
    for i in range(block_repeats):
      is_first = i == 0
      x = block_fn(
          filters=filters,
          temporal_kernel_size=temporal_kernel_sizes[i],
          # Strides apply to the first block only; later blocks use stride 1.
          temporal_strides=temporal_strides if is_first else 1,
          spatial_strides=spatial_strides if is_first else 1,
          stochastic_depth_drop_rate=stochastic_depth_drop_rate,
          # Only apply the self-gating module in the last block.
          use_self_gating=use_self_gating if i == last_index else False,
          se_ratio=self._se_ratio,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(
              x)
    return tf.identity(x, name=name)

  def get_config(self):
    """Returns the constructor arguments used to re-create this model.

    NOTE: `input_specs` is not included, so `from_config` re-creates the model
    with the default `input_specs`.
    """
    config_dict = {
        'model_id': self._model_id,
        'temporal_strides': self._temporal_strides,
        'temporal_kernel_sizes': self._temporal_kernel_sizes,
        'stem_type': self._stem_type,
        'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size,
        'stem_conv_temporal_stride': self._stem_conv_temporal_stride,
        'stem_pool_temporal_stride': self._stem_pool_temporal_stride,
        'use_self_gating': self._use_self_gating,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a model from a `get_config()` dict; `custom_objects` is unused."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds a `ResNet3D` backbone from experiment configs.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the backbone input.
    backbone_config: Backbone config whose `.get()` returns the ResNet-3D
      sub-config (model id, per-group block specs, stem settings, ...).
    norm_activation_config: Config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Optional regularizer passed as the kernel regularizer.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()
  group_specs = list(cfg.block_specs)

  # ResNet3D expects one entry per block group for each per-group setting, so
  # split the block specs into one flat list per field.
  temporal_strides = [spec.temporal_strides for spec in group_specs]
  temporal_kernel_sizes = [spec.temporal_kernel_sizes for spec in group_specs]
  use_self_gating = [spec.use_self_gating for spec in group_specs]

  return ResNet3D(
      model_id=cfg.model_id,
      temporal_strides=temporal_strides,
      temporal_kernel_sizes=temporal_kernel_sizes,
      use_self_gating=use_self_gating,
      input_specs=input_specs,
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
@factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds a ResNet-3D-RS backbone from experiment configs.

  Unlike `build_resnet3d`, each block spec carries a temporal kernel-size
  pattern for its whole group, which is expanded here to one entry per block.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the backbone input.
    backbone_config: Backbone config whose `.get()` returns the ResNet-3D-RS
      sub-config (model id, per-group block specs, stem settings, ...).
    norm_activation_config: Config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Optional regularizer passed as the kernel regularizer.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()
  group_specs = list(cfg.block_specs)

  temporal_strides = [spec.temporal_strides for spec in group_specs]
  use_self_gating = [spec.use_self_gating for spec in group_specs]
  # Repeat each group's kernel-size pattern once per block in that group (the
  # block count is the last field of the matching RESNET_SPECS entry). The
  # result must have exactly one entry per block, so in practice the pattern
  # holds a single kernel size.
  temporal_kernel_sizes = [
      list(spec.temporal_kernel_sizes) * RESNET_SPECS[cfg.model_id][idx][-1]
      for idx, spec in enumerate(group_specs)
  ]

  return ResNet3D(
      model_id=cfg.model_id,
      temporal_strides=temporal_strides,
      temporal_kernel_sizes=temporal_kernel_sizes,
      use_self_gating=use_self_gating,
      input_specs=input_specs,
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)