"""Builder function to construct tf-slim arg_scope for convolution, fc ops.""" |
|
import tensorflow.compat.v1 as tf
import tf_slim as slim

from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2
from object_detection.utils import context_manager
|
|
class KerasLayerHyperparams(object):
  """A hyperparameter configuration object for Keras layers used in
  Object Detection models.
  """

  def __init__(self, hyperparams_config):
    """Builds keras hyperparameter config for layers based on the proto config.

    It automatically converts from Slim layer hyperparameter configs to
    Keras layer hyperparameters. Namely, it:
    - Builds Keras initializers/regularizers instead of Slim ones
    - sets weights_regularizer/initializer to kernel_regularizer/initializer
    - converts batchnorm decay to momentum
    - converts Slim l2 regularizer weights to the equivalent Keras l2 weights

    Contains a hyperparameter configuration for ops that specifies kernel
    initializer, kernel regularizer, activation. Also contains parameters for
    batch norm operators based on the configuration.

    Note that if the batch_norm parameters are not specified in the config
    (i.e. left to default) then batch norm is excluded from the config.

    Args:
      hyperparams_config: hyperparams.proto object containing
        hyperparameters.

    Raises:
      ValueError: if hyperparams_config is not of type hyperparams.Hyperparams.
    """
    if not isinstance(hyperparams_config,
                      hyperparams_pb2.Hyperparams):
      raise ValueError('hyperparams_config not of type '
                       'hyperparams_pb.Hyperparams.')

    self._batch_norm_params = None
    if hyperparams_config.HasField('batch_norm'):
      self._batch_norm_params = _build_keras_batch_norm_params(
          hyperparams_config.batch_norm)

    self._activation_fn = _build_activation_fn(hyperparams_config.activation)
    self._op_params = {
        'kernel_regularizer': _build_keras_regularizer(
            hyperparams_config.regularizer),
        'kernel_initializer': _build_initializer(
            hyperparams_config.initializer, build_for_keras=True),
        'activation': _build_activation_fn(hyperparams_config.activation)
    }
|
  def use_batch_norm(self):
    return self._batch_norm_params is not None

  def batch_norm_params(self, **overrides):
    """Returns a dict containing batchnorm layer construction hyperparameters.

    Optionally overrides values in the batchnorm hyperparam dict. Overrides
    only apply to individual calls of this method, and do not affect
    future calls.

    Args:
      **overrides: keyword arguments to override in the hyperparams dictionary

    Returns: dict containing the layer construction keyword arguments, with
      values overridden by the `overrides` keyword arguments.
    """
    if self._batch_norm_params is None:
      new_batch_norm_params = dict()
    else:
      new_batch_norm_params = self._batch_norm_params.copy()
    new_batch_norm_params.update(overrides)
    return new_batch_norm_params
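
  # A minimal sketch of the override semantics (assumes `hyperparams` is a
  # KerasLayerHyperparams built from a config with a `batch_norm` field):
  #
  #   bn_params = hyperparams.batch_norm_params(momentum=0.997)
  #   # bn_params['momentum'] == 0.997 for this call only; the stored
  #   # hyperparameters are unchanged for future calls.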
|
|
|
  def build_batch_norm(self, training=None, **overrides):
    """Returns a Batch Normalization layer with the appropriate hyperparams.

    If the hyperparams are configured to not use batch normalization,
    this will return a Keras Lambda layer that only applies tf.identity,
    without doing any normalization.

    Optionally overrides values in the batch_norm hyperparam dict. Overrides
    only apply to individual calls of this method, and do not affect
    future calls.

    Args:
      training: if True, the normalization layer will normalize using the batch
        statistics. If False, the normalization layer will be frozen and will
        act as if it is being used for inference. If None, the layer
        will look up the Keras learning phase at `call` time to decide what to
        do.
      **overrides: batch normalization construction args to override from the
        batch_norm hyperparams dictionary.

    Returns: Either a FreezableBatchNorm layer (if use_batch_norm() is True),
      or a Keras Lambda layer that applies the identity (if use_batch_norm()
      is False)
    """
    if self.use_batch_norm():
      return freezable_batch_norm.FreezableBatchNorm(
          training=training,
          **self.batch_norm_params(**overrides)
      )
    else:
      return tf.keras.layers.Lambda(tf.identity)
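
  # A minimal sketch (with a hypothetical `hyperparams` object): the returned
  # layer can be inserted unconditionally, because it is an identity Lambda
  # when batch norm is not configured.
  #
  #   bn_layer = hyperparams.build_batch_norm(training=True, momentum=0.997)
  #   outputs = bn_layer(conv_outputs)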
|
|
|
  def build_activation_layer(self, name='activation'):
    """Returns a Keras layer that applies the desired activation function.

    Args:
      name: The name to assign the Keras layer.
    Returns: A Keras lambda layer that applies the activation function
      specified in the hyperparam config, or applies the identity if the
      activation function is None.
    """
    if self._activation_fn:
      return tf.keras.layers.Lambda(self._activation_fn, name=name)
    else:
      return tf.keras.layers.Lambda(tf.identity, name=name)
|
  def params(self, include_activation=False, **overrides):
    """Returns a dict containing the layer construction hyperparameters to use.

    Optionally overrides values in the returned dict. Overrides
    only apply to individual calls of this method, and do not affect
    future calls.

    Args:
      include_activation: If False, activation in the returned dictionary will
        be set to `None`, and the activation must be applied via a separate
        layer created by `build_activation_layer`. If True, `activation` in the
        output param dictionary will be set to the activation function
        specified in the hyperparams config.
      **overrides: keyword arguments to override in the hyperparams dictionary.

    Returns: dict containing the layer construction keyword arguments, with
      values overridden by the `overrides` keyword arguments.
    """
    new_params = self._op_params.copy()
    new_params['activation'] = None
    if include_activation:
      new_params['activation'] = self._activation_fn
    if self.use_batch_norm() and self.batch_norm_params()['center']:
      new_params['use_bias'] = False
    else:
      new_params['use_bias'] = True
    new_params.update(**overrides)
    return new_params
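
# A minimal end-to-end sketch of how this class might be used. The text proto
# below is illustrative only, not taken from any released config:
#
#   from google.protobuf import text_format
#
#   config = text_format.Parse("""
#     regularizer { l2_regularizer { weight: 0.0004 } }
#     initializer { truncated_normal_initializer { stddev: 0.03 } }
#     activation: RELU_6
#     batch_norm { decay: 0.997 epsilon: 0.001 scale: true }
#   """, hyperparams_pb2.Hyperparams())
#   hyperparams = KerasLayerHyperparams(config)
#
#   conv = tf.keras.layers.Conv2D(64, 3, **hyperparams.params())
#   norm = hyperparams.build_batch_norm(training=True)
#   activation = hyperparams.build_activation_layer()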
|
|
|
|
|
def build(hyperparams_config, is_training):
  """Builds tf-slim arg_scope for convolution ops based on the config.

  Returns an arg_scope to use for convolution ops containing weights
  initializer, weights regularizer, activation function, batch norm function
  and batch norm parameters based on the configuration.

  Note that if no normalization parameters are specified in the config
  (i.e. left to default) then both batch norm and group norm are excluded
  from the arg_scope.

  The batch norm parameters are set for updates based on `is_training` argument
  and conv_hyperparams_config.batch_norm.train parameter. During training, they
  are updated only if batch_norm.train parameter is true. However, during eval,
  no updates are made to the batch norm variables. In both cases, their current
  values are used during forward pass.

  Args:
    hyperparams_config: hyperparams.proto object containing
      hyperparameters.
    is_training: Whether the network is in training mode.

  Returns:
    arg_scope_fn: A function to construct tf-slim arg_scope containing
      hyperparameters for ops.

  Raises:
    ValueError: if hyperparams_config is not of type hyperparams.Hyperparams.
  """
  if not isinstance(hyperparams_config,
                    hyperparams_pb2.Hyperparams):
    raise ValueError('hyperparams_config not of type '
                     'hyperparams_pb.Hyperparams.')

  normalizer_fn = None
  batch_norm_params = None
  if hyperparams_config.HasField('batch_norm'):
    normalizer_fn = slim.batch_norm
    batch_norm_params = _build_batch_norm_params(
        hyperparams_config.batch_norm, is_training)
  if hyperparams_config.HasField('group_norm'):
    normalizer_fn = slim.group_norm
  affected_ops = [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose]
  if hyperparams_config.HasField('op') and (
      hyperparams_config.op == hyperparams_pb2.Hyperparams.FC):
    affected_ops = [slim.fully_connected]
  def scope_fn():
    with (slim.arg_scope([slim.batch_norm], **batch_norm_params)
          if batch_norm_params is not None else
          context_manager.IdentityContextManager()):
      with slim.arg_scope(
          affected_ops,
          weights_regularizer=_build_slim_regularizer(
              hyperparams_config.regularizer),
          weights_initializer=_build_initializer(
              hyperparams_config.initializer),
          activation_fn=_build_activation_fn(hyperparams_config.activation),
          normalizer_fn=normalizer_fn) as sc:
        return sc

  return scope_fn
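
# A minimal sketch of how the returned scope function might be used with
# tf-slim (`conv_hyperparams` stands for any hyperparams_pb2.Hyperparams
# message):
#
#   arg_scope_fn = build(conv_hyperparams, is_training=True)
#   with slim.arg_scope(arg_scope_fn()):
#     net = slim.conv2d(images, 64, [3, 3])
#   # The conv op above picks up the configured initializer, regularizer,
#   # activation function and, if configured, batch norm.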
|
|
|
|
|
def _build_activation_fn(activation_fn):
  """Builds a callable activation from config.

  Args:
    activation_fn: hyperparams_pb2.Hyperparams.activation

  Returns:
    Callable activation function.

  Raises:
    ValueError: On unknown activation function.
  """
  if activation_fn == hyperparams_pb2.Hyperparams.NONE:
    return None
  if activation_fn == hyperparams_pb2.Hyperparams.RELU:
    return tf.nn.relu
  if activation_fn == hyperparams_pb2.Hyperparams.RELU_6:
    return tf.nn.relu6
  if activation_fn == hyperparams_pb2.Hyperparams.SWISH:
    return tf.nn.swish
  raise ValueError('Unknown activation function: {}'.format(activation_fn))
|
|
def _build_slim_regularizer(regularizer):
  """Builds a tf-slim regularizer from config.

  Args:
    regularizer: hyperparams_pb2.Hyperparams.regularizer proto.

  Returns:
    tf-slim regularizer.

  Raises:
    ValueError: On unknown regularizer.
  """
  regularizer_oneof = regularizer.WhichOneof('regularizer_oneof')
  if regularizer_oneof == 'l1_regularizer':
    return slim.l1_regularizer(scale=float(regularizer.l1_regularizer.weight))
  if regularizer_oneof == 'l2_regularizer':
    return slim.l2_regularizer(scale=float(regularizer.l2_regularizer.weight))
  if regularizer_oneof is None:
    return None
  raise ValueError('Unknown regularizer function: {}'.format(regularizer_oneof))
|
|
def _build_keras_regularizer(regularizer):
  """Builds a keras regularizer from config.

  Args:
    regularizer: hyperparams_pb2.Hyperparams.regularizer proto.

  Returns:
    Keras regularizer.

  Raises:
    ValueError: On unknown regularizer.
  """
  regularizer_oneof = regularizer.WhichOneof('regularizer_oneof')
  if regularizer_oneof == 'l1_regularizer':
    return tf.keras.regularizers.l1(float(regularizer.l1_regularizer.weight))
  if regularizer_oneof == 'l2_regularizer':
    # Slim's l2_regularizer wraps tf.nn.l2_loss, which already includes a
    # factor of 0.5; the Keras l2 regularizer does not, so halve the
    # configured weight to keep the two penalties equivalent.
    return tf.keras.regularizers.l2(
        float(regularizer.l2_regularizer.weight * 0.5))
  if regularizer_oneof is None:
    return None
  raise ValueError('Unknown regularizer function: {}'.format(regularizer_oneof))
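
# Worked example of the weight conversion above (illustrative values): a
# config weight of 0.004 builds slim.l2_regularizer(0.004), whose penalty is
# 0.004 * sum(w**2) / 2 = 0.002 * sum(w**2), while the Keras side builds
# tf.keras.regularizers.l2(0.002), whose penalty is 0.002 * sum(w**2), so
# the two penalties match.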
|
|
|
|
|
def _build_initializer(initializer, build_for_keras=False):
  """Build a tf initializer from config.

  Args:
    initializer: hyperparams_pb2.Hyperparams.initializer proto.
    build_for_keras: Whether the initializers should be built for Keras
      operators. If false builds for Slim.

  Returns:
    tf initializer.

  Raises:
    ValueError: On unknown initializer.
  """
  initializer_oneof = initializer.WhichOneof('initializer_oneof')
  if initializer_oneof == 'truncated_normal_initializer':
    return tf.truncated_normal_initializer(
        mean=initializer.truncated_normal_initializer.mean,
        stddev=initializer.truncated_normal_initializer.stddev)
  if initializer_oneof == 'random_normal_initializer':
    return tf.random_normal_initializer(
        mean=initializer.random_normal_initializer.mean,
        stddev=initializer.random_normal_initializer.stddev)
  if initializer_oneof == 'variance_scaling_initializer':
    enum_descriptor = (hyperparams_pb2.VarianceScalingInitializer.
                       DESCRIPTOR.enum_types_by_name['Mode'])
    mode = enum_descriptor.values_by_number[initializer.
                                            variance_scaling_initializer.
                                            mode].name
    if build_for_keras:
      if initializer.variance_scaling_initializer.uniform:
        return tf.variance_scaling_initializer(
            scale=initializer.variance_scaling_initializer.factor,
            mode=mode.lower(),
            distribution='uniform')
      else:
        # Prefer the 'truncated_normal' distribution. Older TF versions that
        # do not accept it raise a ValueError; in that case fall back to
        # 'normal' and rescale the factor to compensate for the variance
        # reduced by the two-sigma truncation in those implementations.
        try:
          return tf.variance_scaling_initializer(
              scale=initializer.variance_scaling_initializer.factor,
              mode=mode.lower(),
              distribution='truncated_normal')
        except ValueError:
          # 0.87962566103423978 is the stddev of a standard normal truncated
          # to [-2, 2].
          truncate_constant = 0.87962566103423978
          truncated_scale = initializer.variance_scaling_initializer.factor / (
              truncate_constant * truncate_constant
          )
          return tf.variance_scaling_initializer(
              scale=truncated_scale,
              mode=mode.lower(),
              distribution='normal')
    else:
      return slim.variance_scaling_initializer(
          factor=initializer.variance_scaling_initializer.factor,
          mode=mode,
          uniform=initializer.variance_scaling_initializer.uniform)
  if initializer_oneof is None:
    return None
  raise ValueError('Unknown initializer function: {}'.format(
      initializer_oneof))
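
# Worked example of the fallback rescaling above (illustrative factor): with
# factor 2.0 and c = 0.87962566103423978 (the stddev of a standard normal
# truncated to [-2, 2]), truncated_scale = 2.0 / c**2, approximately 2.585,
# which restores the variance removed by the two-sigma truncation in older
# implementations of the 'normal' distribution.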
|
|
|
|
|
def _build_batch_norm_params(batch_norm, is_training):
  """Build a dictionary of batch_norm params from config.

  Args:
    batch_norm: hyperparams_pb2.Hyperparams.batch_norm proto.
    is_training: Whether the model is in training mode.

  Returns:
    A dictionary containing batch_norm parameters.
  """
  batch_norm_params = {
      'decay': batch_norm.decay,
      'center': batch_norm.center,
      'scale': batch_norm.scale,
      'epsilon': batch_norm.epsilon,
      # Batch norm statistics are updated only when the model is training and
      # the config allows the batch norm variables to be trained.
      'is_training': is_training and batch_norm.train,
  }
  return batch_norm_params
|
|
def _build_keras_batch_norm_params(batch_norm):
  """Build a dictionary of Keras BatchNormalization params from config.

  Args:
    batch_norm: hyperparams_pb2.Hyperparams.batch_norm proto.

  Returns:
    A dictionary containing Keras BatchNormalization parameters.
  """
  # Keras BatchNormalization calls the exponential moving average coefficient
  # `momentum`, where the Slim batch_norm op calls it `decay`.
  batch_norm_params = {
      'momentum': batch_norm.decay,
      'center': batch_norm.center,
      'scale': batch_norm.scale,
      'epsilon': batch_norm.epsilon,
  }
  return batch_norm_params
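
# Illustrative mapping (hypothetical values): a config of
# `batch_norm { decay: 0.997 center: true scale: true epsilon: 0.001 }`
# becomes {'momentum': 0.997, 'center': True, 'scale': True,
# 'epsilon': 0.001}, ready to be passed to tf.keras.layers.BatchNormalization
# or freezable_batch_norm.FreezableBatchNorm.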
|
|