Spaces:
Building
Building
# Copyright 2024 The etils Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Array spec utils.""" | |
# Is there a way of merging this with array_types ? | |
from __future__ import annotations | |
import functools | |
import sys | |
from typing import Any, Optional | |
from etils.enp import numpy_utils | |
from etils.enp.array_types import typing as array_types | |
from etils.enp.typing import Array | |
import numpy as np | |
# Module-level handle on `numpy_utils.lazy`. It provides the optional framework
# modules (`lazy.jax`, `lazy.tf`) and their availability flags (`lazy.has_jax`,
# `lazy.has_tf`) used by the checks in this file.
lazy = numpy_utils.lazy
class UnknownArrayError(TypeError):
  """Raised when a value cannot be interpreted as an array or array spec.

  Subclasses `TypeError`, so callers catching `TypeError` also catch it.
  """
@functools.lru_cache()
def _get_none_spec():  # -> tf.TypeSpec:
  """Returns the tf.NoneTensorSpec().

  The result is cached because building it requires constructing a
  `tf.data.Dataset` pipeline. `ArraySpec.from_array` calls this for every value
  that reaches its `NoneTensorSpec` check.

  Requires TensorFlow to be available (`lazy.has_tf`).
  """
  assert lazy.has_tf
  # We need this hack as NoneTensorSpec is not exposed in the public API.
  # (see: b/191132147)
  ds = lazy.tf.data.Dataset.range(0)
  # Mapping to `(x, None)` makes the second element spec a NoneTensorSpec.
  ds = ds.map(lambda x: (x, None))
  return ds.element_spec[-1]
class ArraySpec:
  """Structure containing shape/dtype.

  Attributes:
    shape: The array shape, stored as a `tuple`.
    dtype: The array dtype, stored as a `np.dtype`.
  """

  __slots__ = ['shape', 'dtype']

  def __init__(self, shape, dtype):
    """Builds the spec.

    Args:
      shape: Iterable of dimensions, converted to a `tuple`.
      dtype: Anything accepted by `np.dtype`. String dtypes (as detected by
        `numpy_utils.is_dtype_str`) are normalized to the object dtype `'O'`.
    """
    if numpy_utils.is_dtype_str(dtype):  # Normalize `str` dtype
      dtype = np.dtype('O')
    self.shape = tuple(shape)
    self.dtype = np.dtype(dtype)

  def __repr__(self) -> str:
    # Display the spec the same way as the equivalent `ArrayAliasMeta` type.
    array_type = array_types.ArrayAliasMeta(
        dtype=self.dtype,
        shape=self.shape,
    )
    return repr(array_type)

  def __eq__(self, other) -> bool:
    if not isinstance(other, type(self)):
      return False
    return (other.shape, other.dtype) == (self.shape, self.dtype)

  def __hash__(self) -> int:
    # Hashes the same `(shape, dtype)` pair `__eq__` compares.
    return hash((self.shape, self.dtype))

  @classmethod
  def is_array(cls, array: Any) -> bool:
    """Returns `True` if the given value can be converted to `ArraySpec`.

    Only `UnknownArrayError` is treated as "not an array"; any other exception
    raised by `from_array` propagates.
    """
    try:
      cls.from_array(array)
    except UnknownArrayError:
      return False
    else:
      return True

  @classmethod
  def from_array(cls, array: Array) -> Optional[ArraySpec]:
    """Construct the `ArraySpec` from the given array.

    Args:
      array: One of the supported types: `np.ndarray`, numpy scalar,
        `ArraySpec`, `jax.Array` (including PRNG key arrays),
        `jax.ShapeDtypeStruct`, `tf.Tensor`, `tf.TensorSpec`,
        `tf.NoneTensorSpec`, grain `ArraySpec`, orbax array/scalar metadata,
        flax summary array representation, or `ArrayAliasMeta` with an
        all-integer shape.

    Returns:
      The `ArraySpec` (built with `cls`), or `None` for `tf.NoneTensorSpec`.

    Raises:
      UnknownArrayError: If `array` is not a supported type, or is an
        `ArrayAliasMeta` whose shape contains non-integer (dynamic) dims.
    """
    # Could refactor with some dynamic registration mechanism.
    if isinstance(array, (np.ndarray, np.generic, ArraySpec)):
      shape = array.shape
      dtype = array.dtype
    elif (
        lazy.has_jax
        and isinstance(array, lazy.jax.Array)
        and lazy.jax.dtypes.issubdtype(array.dtype, lazy.jax.dtypes.prng_key)
    ):
      # Typed PRNG key arrays are reported with a `uint32` dtype.
      shape = array.shape
      dtype = np.uint32  # `jax.random.PRNGKeyArray` is a constant
    elif lazy.has_jax and isinstance(
        array,
        (lazy.jax.ShapeDtypeStruct, lazy.jax.Array),
    ):
      shape = array.shape
      dtype = array.dtype
    elif lazy.has_tf and isinstance(
        array,
        (lazy.tf.TensorSpec, lazy.tf.Tensor),
    ):
      # In graph mode, `.shape` values can be `Dimension(32)`; unknown dims
      # stay `None`.
      shape = tuple(int(s) if s is not None else s for s in array.shape)
      dtype = array.dtype.as_numpy_dtype
    elif lazy.has_tf and isinstance(array, type(_get_none_spec())):
      return None  # Special case for `NoneTensorSpec()`
    elif _is_grain(array) or _is_orbax(array) or _is_flax_summarry(array):
      shape = array.shape
      dtype = array.dtype
    elif isinstance(array, array_types.ArrayAliasMeta):
      # Materialize inside the `try`. A lazy generator would defer `int()` to
      # `cls(...)` below and leak a raw `ValueError` for dynamic shapes.
      try:
        shape = tuple(int(s) for s in array.shape.split())
      except ValueError:
        raise UnknownArrayError(
            f'Not supported dynamic shape: {array}'
        ) from None
      dtype = array.dtype.np_dtype
    else:
      raise UnknownArrayError(f'Unknown array-like type: {type(array)}')
    # Should we also handle `bytes` case ?
    return cls(shape=shape, dtype=dtype)
def is_fake_array(array: Array) -> bool:
  """Returns `True` if `array` is a fake (spec-like) array.

  Recognized types, checked in this order: `jax.ShapeDtypeStruct` (when jax is
  available), `tf.TensorSpec` (when TF is available), `ArraySpec`, orbax
  array/scalar metadata, grain `ArraySpec`, flax summary array
  representations, and `ArrayAliasMeta` types.
  """
  if lazy.has_jax and isinstance(array, lazy.jax.ShapeDtypeStruct):
    return True
  if lazy.has_tf and isinstance(array, lazy.tf.TensorSpec):
    return True
  if isinstance(array, ArraySpec):
    return True
  for is_library_spec in (_is_orbax, _is_grain, _is_flax_summarry):
    if is_library_spec(array):
      return True
  return isinstance(array, array_types.ArrayAliasMeta)
def _is_flax_summarry(value: Array) -> bool:
  """Returns `True` if `value` is a flax summary array representation.

  Returns `False` without importing flax unless `flax.linen` is already in
  `sys.modules`.
  """
  if 'flax.linen' in sys.modules:
    from flax import linen as nn  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
    return isinstance(value, nn.summary._ArrayRepresentation)  # pylint: disable=protected-access
  return False
def _is_grain(array: Array) -> bool:
  """Returns `True` if `array` is a `grain.tensorflow.ArraySpec`.

  Returns `False` without importing grain unless `grain.tensorflow` is already
  in `sys.modules`.
  """
  if 'grain.tensorflow' in sys.modules:
    from grain import tensorflow as grain  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
    return isinstance(array, grain.ArraySpec)
  return False
def _is_orbax(array: Array) -> bool:
  """Returns `True` if `array` is orbax `ArrayMetadata` or `ScalarMetadata`.

  Returns `False` without importing orbax unless `orbax.checkpoint` is already
  in `sys.modules`.
  """
  if 'orbax.checkpoint' in sys.modules:
    from orbax import checkpoint as ocp  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
    metadata_types = (
        ocp.type_handlers.ArrayMetadata,
        ocp.type_handlers.ScalarMetadata,
    )
    return isinstance(array, metadata_types)
  return False