# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Common TF utilities.""" import functools import inspect import six import tensorflow as tf, tf_keras from tensorflow.python.util import deprecation from official.modeling import activations @deprecation.deprecated( None, "tf_keras.layers.Layer supports multiple positional args and kwargs as " "input tensors. pack/unpack inputs to override __call__ is no longer " "needed.") def pack_inputs(inputs): """Pack a list of `inputs` tensors to a tuple. Args: inputs: a list of tensors. Returns: a tuple of tensors. if any input is None, replace it with a special constant tensor. """ inputs = tf.nest.flatten(inputs) outputs = [] for x in inputs: if x is None: outputs.append(tf.constant(0, shape=[], dtype=tf.int32)) else: outputs.append(x) return tuple(outputs) @deprecation.deprecated( None, "tf_keras.layers.Layer supports multiple positional args and kwargs as " "input tensors. pack/unpack inputs to override __call__ is no longer " "needed.") def unpack_inputs(inputs): """unpack a tuple of `inputs` tensors to a tuple. Args: inputs: a list of tensors. Returns: a tuple of tensors. if any input is a special constant tensor, replace it with None. """ inputs = tf.nest.flatten(inputs) outputs = [] for x in inputs: if is_special_none_tensor(x): outputs.append(None) else: outputs.append(x) x = tuple(outputs) # To trick the very pointless 'unbalanced-tuple-unpacking' pylint check # from triggering. if len(x) == 1: return x[0] return tuple(outputs) def is_special_none_tensor(tensor): """Checks if a tensor is a special None Tensor.""" return tensor.shape.ndims == 0 and tensor.dtype == tf.int32 def get_activation(identifier, use_keras_layer=False, **kwargs): """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`. It checks string first and if it is one of customized activation not in TF, the corresponding activation will be returned. For non-customized activation names and callable identifiers, always fallback to tf_keras.activations.get. Prefers using keras layers when use_keras_layer=True. Now it only supports 'relu', 'linear', 'identity', 'swish', 'mish', 'leaky_relu', and 'gelu'. Args: identifier: String name of the activation function or callable. use_keras_layer: If True, use keras layer if identifier is allow-listed. **kwargs: Keyword arguments to use to instantiate an activation function. Available only for 'leaky_relu' and 'gelu' when using keras layers. For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1) Returns: A Python function corresponding to the activation function or a keras activation layer when use_keras_layer=True. """ if isinstance(identifier, six.string_types): identifier = str(identifier).lower() if use_keras_layer: keras_layer_allowlist = { "relu": "relu", "linear": "linear", "identity": "linear", "swish": "swish", "sigmoid": "sigmoid", "relu6": tf.nn.relu6, "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs), "hard_swish": activations.hard_swish, "hard_sigmoid": activations.hard_sigmoid, "mish": activations.mish, "gelu": functools.partial(tf.nn.gelu, **kwargs), } if identifier in keras_layer_allowlist: return tf_keras.layers.Activation(keras_layer_allowlist[identifier]) name_to_fn = { "gelu": activations.gelu, "simple_swish": activations.simple_swish, "hard_swish": activations.hard_swish, "relu6": activations.relu6, "hard_sigmoid": activations.hard_sigmoid, "identity": activations.identity, "mish": activations.mish, } if identifier in name_to_fn: return tf_keras.activations.get(name_to_fn[identifier]) return tf_keras.activations.get(identifier) def get_shape_list(tensor, expected_rank=None, name=None): """Returns a list of the shape of tensor, preferring static dimensions. Args: tensor: A tf.Tensor object to find the shape of. expected_rank: (optional) int. The expected rank of `tensor`. If this is specified and the `tensor` has a different rank, and exception will be thrown. name: Optional name of the tensor for the error message. Returns: A list of dimensions of the shape of tensor. All static dimensions will be returned as python integers, and dynamic dimensions will be returned as tf.Tensor scalars. """ if expected_rank is not None: assert_rank(tensor, expected_rank, name) shape = tensor.shape.as_list() non_static_indexes = [] for (index, dim) in enumerate(shape): if dim is None: non_static_indexes.append(index) if not non_static_indexes: return shape dyn_shape = tf.shape(tensor) for index in non_static_indexes: shape[index] = dyn_shape[index] return shape def assert_rank(tensor, expected_rank, name=None): """Raises an exception if the tensor rank is not of the expected rank. Args: tensor: A tf.Tensor to check the rank of. expected_rank: Python integer or list of integers, expected rank. name: Optional name of the tensor for the error message. Raises: ValueError: If the expected shape doesn't match the actual shape. """ expected_rank_dict = {} if isinstance(expected_rank, six.integer_types): expected_rank_dict[expected_rank] = True else: for x in expected_rank: expected_rank_dict[x] = True actual_rank = tensor.shape.ndims if actual_rank not in expected_rank_dict: raise ValueError( "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not " "equal to the expected tensor rank `%s`" % (name, actual_rank, str(tensor.shape), str(expected_rank))) def safe_mean(losses): """Computes a safe mean of the losses. Args: losses: `Tensor` whose elements contain individual loss measurements. Returns: A scalar representing the mean of `losses`. If `num_present` is zero, then zero is returned. """ total = tf.reduce_sum(losses) num_elements = tf.cast(tf.size(losses), dtype=losses.dtype) return tf.math.divide_no_nan(total, num_elements) def get_replica_id(): """Gets replica id depending on the environment.""" context = tf.distribute.get_replica_context() if context is not None: return context.replica_id_in_sync_group else: raise RuntimeError("Unknown replica context. The `get_replica_id` method " "relies on TF 2.x tf.distribute API.") def cross_replica_concat(value, axis, name="cross_replica_concat"): """Concatenates the given `value` across (GPU/TPU) cores, along `axis`. In general, each core ("replica") will pass a replica-specific value as `value` (corresponding to some element of a data-parallel computation taking place across replicas). The resulting concatenated `Tensor` will have the same shape as `value` for all dimensions except `axis`, where it will be larger by a factor of the number of replicas. It will also have the same `dtype` as `value`. The position of a given replica's `value` within the resulting concatenation is determined by that replica's replica ID. For example: With `value` for replica 0 given as 0 0 0 0 0 0 and `value` for replica 1 given as 1 1 1 1 1 1 the resulting concatenation along axis 0 will be 0 0 0 0 0 0 1 1 1 1 1 1 and this result will be identical across all replicas. Note that this API only works in TF2 with `tf.distribute`. Args: value: The `Tensor` to concatenate across replicas. Each replica will have a different value for this `Tensor`, and these replica-specific values will be concatenated. axis: The axis along which to perform the concatenation as a Python integer (not a `Tensor`). E.g., `axis=0` to concatenate along the batch dimension. name: A name for the operation (used to create a name scope). Returns: The result of concatenating `value` along `axis` across replicas. Raises: RuntimeError: when the batch (0-th) dimension is None. """ with tf.name_scope(name): context = tf.distribute.get_replica_context() # Typically this could be hit only if the tensor is derived from a # dataset with finite epochs and drop_remainder=False, where the last # batch could of different batch size and then the dim-0 is of dynamic # shape. if value.shape.as_list()[0] is None: raise RuntimeError(f"{value} has unknown batch.") return context.all_gather(value, axis=axis) def clone_initializer(initializer): # Keras initializer is going to be stateless, which mean reusing the same # initializer will produce same init value when the shapes are the same. if isinstance(initializer, tf_keras.initializers.Initializer): return initializer.__class__.from_config(initializer.get_config()) # When the input is string/dict or other serialized configs, caller will # create a new keras Initializer instance based on that, and we don't need to # do anything return initializer def serialize_keras_object(obj): if hasattr(tf_keras.utils, "legacy"): return tf_keras.utils.legacy.serialize_keras_object(obj) else: return tf_keras.utils.serialize_keras_object(obj) def deserialize_keras_object( config, module_objects=None, custom_objects=None, printable_module_name=None ): if hasattr(tf_keras.utils, "legacy"): return tf_keras.utils.legacy.deserialize_keras_object( config, custom_objects, module_objects, printable_module_name ) else: return tf_keras.utils.deserialize_keras_object( config, custom_objects, module_objects, printable_module_name ) def serialize_layer(layer, use_legacy_format=False): if ( "use_legacy_format" in inspect.getfullargspec(tf_keras.layers.serialize).args ): return tf_keras.layers.serialize(layer, use_legacy_format=use_legacy_format) else: return tf_keras.layers.serialize(layer) def serialize_initializer(initializer, use_legacy_format=False): if ( "use_legacy_format" in inspect.getfullargspec(tf_keras.initializers.serialize).args ): return tf_keras.initializers.serialize( initializer, use_legacy_format=use_legacy_format ) else: return tf_keras.initializers.serialize(initializer) def serialize_regularizer(regularizer, use_legacy_format=False): if ( "use_legacy_format" in inspect.getfullargspec(tf_keras.regularizers.serialize).args ): return tf_keras.regularizers.serialize( regularizer, use_legacy_format=use_legacy_format ) else: return tf_keras.regularizers.serialize(regularizer) def serialize_constraint(constraint, use_legacy_format=False): if ( "use_legacy_format" in inspect.getfullargspec(tf_keras.constraints.serialize).args ): return tf_keras.constraints.serialize( constraint, use_legacy_format=use_legacy_format ) else: return tf_keras.constraints.serialize(constraint) def serialize_activation(activation, use_legacy_format=False): if ( "use_legacy_format" in inspect.getfullargspec(tf_keras.activations.serialize).args ): return tf_keras.activations.serialize( activation, use_legacy_format=use_legacy_format ) else: return tf_keras.activations.serialize(activation)