# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common TF utilities."""
import functools | |
import inspect | |
import six | |
import tensorflow as tf, tf_keras | |
from tensorflow.python.util import deprecation | |
from official.modeling import activations | |
def pack_inputs(inputs):
  """Packs a list of `inputs` tensors into a tuple.

  Args:
    inputs: a list of tensors.

  Returns:
    a tuple of tensors. Any `None` entry is replaced with a special scalar
    int32 constant tensor (recognized later by `is_special_none_tensor`).
  """
  flat_inputs = tf.nest.flatten(inputs)
  return tuple(
      tf.constant(0, shape=[], dtype=tf.int32) if tensor is None else tensor
      for tensor in flat_inputs)
def unpack_inputs(inputs):
  """Unpacks a tuple of `inputs` tensors to a tuple.

  Args:
    inputs: a list of tensors.

  Returns:
    a tuple of tensors where every special constant placeholder tensor
    (see `is_special_none_tensor`) is replaced with None. When exactly one
    element remains, that element is returned directly instead of a 1-tuple.
  """
  flat_inputs = tf.nest.flatten(inputs)
  unpacked = tuple(
      None if is_special_none_tensor(tensor) else tensor
      for tensor in flat_inputs)
  # Returning the lone element directly keeps pylint's pointless
  # 'unbalanced-tuple-unpacking' check from triggering at call sites.
  if len(unpacked) == 1:
    return unpacked[0]
  return unpacked
def is_special_none_tensor(tensor):
  """Checks if a tensor is the scalar int32 placeholder used for `None`."""
  return tensor.dtype == tf.int32 and tensor.shape.ndims == 0
def get_activation(identifier, use_keras_layer=False, **kwargs):
  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.

  It checks string first and if it is one of customized activation not in TF,
  the corresponding activation will be returned. For non-customized activation
  names and callable identifiers, always fallback to tf_keras.activations.get.

  Prefers using keras layers when use_keras_layer=True. Now it only supports
  'relu', 'linear', 'identity', 'swish', 'mish', 'leaky_relu', and 'gelu'.

  Args:
    identifier: String name of the activation function or callable.
    use_keras_layer: If True, use keras layer if identifier is allow-listed.
    **kwargs: Keyword arguments to use to instantiate an activation function.
      Available only for 'leaky_relu' and 'gelu' when using keras layers.
      For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)

  Returns:
    A Python function corresponding to the activation function or a keras
    activation layer when use_keras_layer=True.
  """
  if not isinstance(identifier, six.string_types):
    # Callables and other non-string identifiers resolve via keras directly.
    return tf_keras.activations.get(identifier)

  name = str(identifier).lower()
  if use_keras_layer:
    # Allow-listed activations wrapped as a keras Activation layer.
    layer_impls = {
        "relu": "relu",
        "linear": "linear",
        "identity": "linear",
        "swish": "swish",
        "sigmoid": "sigmoid",
        "relu6": tf.nn.relu6,
        "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
        "hard_swish": activations.hard_swish,
        "hard_sigmoid": activations.hard_sigmoid,
        "mish": activations.mish,
        "gelu": functools.partial(tf.nn.gelu, **kwargs),
    }
    if name in layer_impls:
      return tf_keras.layers.Activation(layer_impls[name])
  # Customized activations not registered in TF under these names.
  custom_fns = {
      "gelu": activations.gelu,
      "simple_swish": activations.simple_swish,
      "hard_swish": activations.hard_swish,
      "relu6": activations.relu6,
      "hard_sigmoid": activations.hard_sigmoid,
      "identity": activations.identity,
      "mish": activations.mish,
  }
  if name in custom_fns:
    return tf_keras.activations.get(custom_fns[name])
  return tf_keras.activations.get(name)
def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, and exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  static_shape = tensor.shape.as_list()
  dynamic_indexes = [
      index for index, dim in enumerate(static_shape) if dim is None
  ]
  if not dynamic_indexes:
    # Fully static shape: plain python ints only.
    return static_shape

  dynamic_shape = tf.shape(tensor)
  for index in dynamic_indexes:
    static_shape[index] = dynamic_shape[index]
  return static_shape
def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape.
  """
  # On Python 3 `six.integer_types` is just `(int,)`, so a plain isinstance
  # check is equivalent. A set gives the same membership semantics as the
  # previous dict-of-True.
  if isinstance(expected_rank, int):
    expected_ranks = {expected_rank}
  else:
    expected_ranks = set(expected_rank)

  actual_rank = tensor.shape.ndims
  if actual_rank not in expected_ranks:
    # `%s` (not `%d`) for the actual rank: a tensor of unknown rank has
    # `ndims is None`, and `%d` would raise TypeError instead of the
    # documented ValueError. Output is unchanged for integer ranks.
    raise ValueError(
        "For the tensor `%s`, the actual tensor rank `%s` (shape = %s) is not "
        "equal to the expected tensor rank `%s`" %
        (name, actual_rank, str(tensor.shape), str(expected_rank)))
def safe_mean(losses):
  """Computes a safe mean of the losses.

  Args:
    losses: `Tensor` whose elements contain individual loss measurements.

  Returns:
    A scalar representing the mean of `losses`. Zero is returned when the
    tensor is empty, since `divide_no_nan` yields 0 for a zero divisor.
  """
  count = tf.cast(tf.size(losses), dtype=losses.dtype)
  return tf.math.divide_no_nan(tf.reduce_sum(losses), count)
def get_replica_id():
  """Gets replica id depending on the environment."""
  replica_context = tf.distribute.get_replica_context()
  if replica_context is None:
    raise RuntimeError("Unknown replica context. The `get_replica_id` method "
                       "relies on TF 2.x tf.distribute API.")
  return replica_context.replica_id_in_sync_group
def cross_replica_concat(value, axis, name="cross_replica_concat"):
  """Concatenates the given `value` across (GPU/TPU) cores, along `axis`.

  Each replica passes its own replica-specific `value` (an element of a
  data-parallel computation spread across replicas). The concatenated result
  has the same shape as `value` in every dimension except `axis`, where it is
  larger by a factor of the number of replicas, and the same `dtype`. A
  replica's `value` is placed within the concatenation according to its
  replica ID. For example, with `value` for replica 0 given as

      0 0 0
      0 0 0

  and `value` for replica 1 given as

      1 1 1
      1 1 1

  the concatenation along axis 0 is

      0 0 0
      0 0 0
      1 1 1
      1 1 1

  and this result is identical across all replicas.

  Note that this API only works in TF2 with `tf.distribute`.

  Args:
    value: The `Tensor` to concatenate across replicas. Each replica will have a
      different value for this `Tensor`, and these replica-specific values will
      be concatenated.
    axis: The axis along which to perform the concatenation as a Python integer
      (not a `Tensor`). E.g., `axis=0` to concatenate along the batch dimension.
    name: A name for the operation (used to create a name scope).

  Returns:
    The result of concatenating `value` along `axis` across replicas.

  Raises:
    RuntimeError: when the batch (0-th) dimension is None.
  """
  with tf.name_scope(name):
    replica_context = tf.distribute.get_replica_context()
    batch_dim = value.shape.as_list()[0]
    # A dynamic dim 0 typically comes from a dataset with finite epochs and
    # drop_remainder=False, where the last batch may be a different size.
    if batch_dim is None:
      raise RuntimeError(f"{value} has unknown batch.")
    return replica_context.all_gather(value, axis=axis)
def clone_initializer(initializer):
  """Returns a fresh copy of `initializer` when it is a keras Initializer.

  Keras initializers are stateless, so reusing one instance for multiple
  variables of the same shape produces identical init values; cloning from
  config avoids that. Strings, dicts and other serialized configs are
  returned unchanged — the caller builds a new Initializer from them anyway.
  """
  if not isinstance(initializer, tf_keras.initializers.Initializer):
    return initializer
  initializer_cls = initializer.__class__
  return initializer_cls.from_config(initializer.get_config())
def serialize_keras_object(obj):
  """Serializes `obj`, preferring the legacy keras API when available."""
  utils = tf_keras.utils
  serialize_fn = (
      utils.legacy.serialize_keras_object
      if hasattr(utils, "legacy")
      else utils.serialize_keras_object
  )
  return serialize_fn(obj)
def deserialize_keras_object(
    config, module_objects=None, custom_objects=None, printable_module_name=None
):
  """Deserializes `config`, preferring the legacy keras API when available.

  NOTE(review): both branches pass `custom_objects` *before* `module_objects`
  positionally — the reverse of this function's own parameter order. Verify
  this matches the positional signatures of the targeted
  `deserialize_keras_object` APIs before relying on `module_objects`.
  """
  if hasattr(tf_keras.utils, "legacy"):
    return tf_keras.utils.legacy.deserialize_keras_object(
        config, custom_objects, module_objects, printable_module_name
    )
  else:
    return tf_keras.utils.deserialize_keras_object(
        config, custom_objects, module_objects, printable_module_name
    )
def serialize_layer(layer, use_legacy_format=False):
  """Serializes a keras layer, forwarding `use_legacy_format` if supported."""
  serialize_fn = tf_keras.layers.serialize
  kwargs = {}
  # Older keras versions do not accept the `use_legacy_format` argument.
  if "use_legacy_format" in inspect.getfullargspec(serialize_fn).args:
    kwargs["use_legacy_format"] = use_legacy_format
  return serialize_fn(layer, **kwargs)
def serialize_initializer(initializer, use_legacy_format=False):
  """Serializes a keras initializer, forwarding `use_legacy_format` if supported."""
  serialize_fn = tf_keras.initializers.serialize
  kwargs = {}
  # Older keras versions do not accept the `use_legacy_format` argument.
  if "use_legacy_format" in inspect.getfullargspec(serialize_fn).args:
    kwargs["use_legacy_format"] = use_legacy_format
  return serialize_fn(initializer, **kwargs)
def serialize_regularizer(regularizer, use_legacy_format=False):
  """Serializes a keras regularizer, forwarding `use_legacy_format` if supported."""
  serialize_fn = tf_keras.regularizers.serialize
  kwargs = {}
  # Older keras versions do not accept the `use_legacy_format` argument.
  if "use_legacy_format" in inspect.getfullargspec(serialize_fn).args:
    kwargs["use_legacy_format"] = use_legacy_format
  return serialize_fn(regularizer, **kwargs)
def serialize_constraint(constraint, use_legacy_format=False):
  """Serializes a keras constraint, forwarding `use_legacy_format` if supported."""
  serialize_fn = tf_keras.constraints.serialize
  kwargs = {}
  # Older keras versions do not accept the `use_legacy_format` argument.
  if "use_legacy_format" in inspect.getfullargspec(serialize_fn).args:
    kwargs["use_legacy_format"] = use_legacy_format
  return serialize_fn(constraint, **kwargs)
def serialize_activation(activation, use_legacy_format=False):
  """Serializes a keras activation, forwarding `use_legacy_format` if supported."""
  serialize_fn = tf_keras.activations.serialize
  kwargs = {}
  # Older keras versions do not accept the `use_legacy_format` argument.
  if "use_legacy_format" in inspect.getfullargspec(serialize_fn).args:
    kwargs["use_legacy_format"] = use_legacy_format
  return serialize_fn(activation, **kwargs)