Spaces:
Sleeping
Sleeping
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains common building blocks for neural networks."""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf, tf_keras | |
from official.modeling import tf_utils | |
class ResidualBlock(tf_keras.layers.Layer):
  """Two-convolution residual block with a skip connection.

  The main path is conv3x3 -> norm -> activation -> conv3x3 -> norm. Its
  output is added to the shortcut (the input itself, or a 1x1 projection of
  the input followed by normalization) and the activation is applied to the
  sum.
  """

  def __init__(self,
               filters,
               strides,
               use_projection=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Records the block hyperparameters; sub-layers are created in `build`.

    Args:
      filters: `int` number of filters used by both 3x3 convolutions and by
        the projection shortcut.
      strides: `int` stride of the first 3x3 convolution and of the
        projection shortcut. The second convolution always uses stride 1.
      use_projection: `bool`, if True the shortcut is a 1x1 convolution
        followed by normalization; otherwise the input is used unchanged.
      kernel_initializer: kernel initializer passed to every Conv2D.
      kernel_regularizer: kernel regularizer passed to every Conv2D. Default
        to None.
      bias_regularizer: bias regularizer passed to every Conv2D. Default to
        None.
      activation: `str` name of the activation function, resolved with
        `tf_utils.get_activation`.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` passed as `momentum` to the normalization
        layers.
      norm_epsilon: `float` passed as `epsilon` to the normalization layers.
      **kwargs: keyword arguments forwarded to the base layer constructor.
    """
    super().__init__(**kwargs)

    # Constructor arguments are stored as given so `get_config` can return
    # them unchanged.
    self._filters = filters
    self._strides = strides
    self._use_projection = use_projection
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._activation = activation
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon

    # Only the selected normalization class is looked up.
    self._norm = (
        tf_keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf_keras.layers.BatchNormalization)
    channels_last = tf_keras.backend.image_data_format() == 'channels_last'
    self._bn_axis = -1 if channels_last else 1
    self._activation_fn = tf_utils.get_activation(activation)

  def _make_conv(self, kernel_size, strides, padding='valid'):
    """Returns a new bias-free Conv2D with `filters` output channels."""
    return tf_keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)

  def _make_norm(self):
    """Returns a new normalization layer acting on the channel axis."""
    return self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

  def build(self, input_shape):
    """Creates the convolution and normalization sub-layers."""
    if self._use_projection:
      # The projection uses the same filter count and stride as the first
      # convolution of the main path.
      self._shortcut = self._make_conv(kernel_size=1, strides=self._strides)
      self._norm0 = self._make_norm()

    self._conv1 = self._make_conv(
        kernel_size=3, strides=self._strides, padding='same')
    self._norm1 = self._make_norm()

    self._conv2 = self._make_conv(kernel_size=3, strides=1, padding='same')
    self._norm2 = self._make_norm()

    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with the constructor args."""
    block_config = {
        'filters': self._filters,
        'strides': self._strides,
        'use_projection': self._use_projection,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    # Base entries come first; block entries win on any key collision.
    merged = dict(super().get_config())
    merged.update(block_config)
    return merged

  def call(self, inputs):
    """Runs the block on `inputs` and returns the activated sum."""
    if self._use_projection:
      residual = self._norm0(self._shortcut(inputs))
    else:
      residual = inputs

    out = self._norm1(self._conv1(inputs))
    out = self._activation_fn(out)
    out = self._norm2(self._conv2(out))

    return self._activation_fn(out + residual)
class BottleneckBlock(tf_keras.layers.Layer):
  """Three-convolution bottleneck block with a skip connection.

  The main path is conv1x1 -> norm -> activation -> conv3x3 -> norm ->
  activation -> conv1x1 -> norm, where the last convolution has
  `filters * 4` output channels. Its output is added to the shortcut (the
  input itself, or a 1x1 projection of the input followed by normalization)
  and the activation is applied to the sum.
  """

  def __init__(self,
               filters,
               strides,
               use_projection=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """Records the block hyperparameters; sub-layers are created in `build`.

    Args:
      filters: `int` number of filters for the first two convolutions. The
        third convolution and the projection shortcut use 4 times as many.
      strides: `int` stride of the 3x3 convolution and of the projection
        shortcut. Both 1x1 convolutions on the main path use stride 1.
      use_projection: `bool`, if True the shortcut is a 1x1 convolution
        followed by normalization; otherwise the input is used unchanged.
      kernel_initializer: kernel initializer passed to every Conv2D.
      kernel_regularizer: kernel regularizer passed to every Conv2D. Default
        to None.
      bias_regularizer: bias regularizer passed to every Conv2D. Default to
        None.
      activation: `str` name of the activation function, resolved with
        `tf_utils.get_activation`.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` passed as `momentum` to the normalization
        layers.
      norm_epsilon: `float` passed as `epsilon` to the normalization layers.
      **kwargs: keyword arguments forwarded to the base layer constructor.
    """
    super().__init__(**kwargs)

    # Constructor arguments are stored as given so `get_config` can return
    # them unchanged.
    self._filters = filters
    self._strides = strides
    self._use_projection = use_projection
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._activation = activation
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon

    # Only the selected normalization class is looked up.
    self._norm = (
        tf_keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf_keras.layers.BatchNormalization)
    channels_last = tf_keras.backend.image_data_format() == 'channels_last'
    self._bn_axis = -1 if channels_last else 1
    self._activation_fn = tf_utils.get_activation(activation)

  def _make_conv(self, filters, kernel_size, strides, padding='valid'):
    """Returns a new bias-free Conv2D sharing the block's kernel settings."""
    return tf_keras.layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)

  def _make_norm(self):
    """Returns a new normalization layer acting on the channel axis."""
    return self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

  def build(self, input_shape):
    """Creates the convolution and normalization sub-layers."""
    if self._use_projection:
      # The projection uses the expanded filter count of the last
      # convolution and the stride of the 3x3 convolution.
      self._shortcut = self._make_conv(
          filters=self._filters * 4, kernel_size=1, strides=self._strides)
      self._norm0 = self._make_norm()

    self._conv1 = self._make_conv(
        filters=self._filters, kernel_size=1, strides=1)
    self._norm1 = self._make_norm()

    self._conv2 = self._make_conv(
        filters=self._filters,
        kernel_size=3,
        strides=self._strides,
        padding='same')
    self._norm2 = self._make_norm()

    self._conv3 = self._make_conv(
        filters=self._filters * 4, kernel_size=1, strides=1)
    self._norm3 = self._make_norm()

    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with the constructor args."""
    block_config = {
        'filters': self._filters,
        'strides': self._strides,
        'use_projection': self._use_projection,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    # Base entries come first; block entries win on any key collision.
    merged = dict(super().get_config())
    merged.update(block_config)
    return merged

  def call(self, inputs):
    """Runs the block on `inputs` and returns the activated sum."""
    if self._use_projection:
      residual = self._norm0(self._shortcut(inputs))
    else:
      residual = inputs

    out = self._norm1(self._conv1(inputs))
    out = self._activation_fn(out)

    out = self._norm2(self._conv2(out))
    out = self._activation_fn(out)

    out = self._norm3(self._conv3(out))

    return self._activation_fn(out + residual)