# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Numpy utils.
Attributes:
  tau: The circle constant (2 * pi). (https://tauday.com/)
"""
from __future__ import annotations
import sys
import typing
from typing import Any, Optional, TypeVar
from etils import epy
import numpy as np
if typing.TYPE_CHECKING:
from etils.enp.typing import Array
_T = TypeVar('_T')
# TODO(pytype): Ideally should use `-> Literal[np]:` but Python does not
# support this: https://github.com/python/typing/issues/1039
# Thankfully, pytype correctly auto-infers `np` when returned by `get_xnp`.
NpModule = Any
# Mirror math.tau (PEP 628). See https://tauday.com/
tau = 2 * np.pi
# When `strict=False` (in `get_xnp`, `is_array`,...), those types are also
# accepted:
_ARRAY_LIKE_TYPES = (int, bool, float, list, tuple)
# During the class construction, pytype fails because of a name conflict
# between the `np` `@property` and the module.
_np = np
class _LazyArrayMeta(type):
def __instancecheck__(cls, obj) -> bool:
return lazy.is_array(obj)
class _LazyImporter:
  """Lazy import module.

  Helps write code that works seamlessly with np, Jax, TF, and torch.

  Because the libs are lazily imported, TF, Jax, and torch are always optional
  dependencies.
  """
@property
def has_jax(self) -> bool:
return 'jax' in sys.modules
@property
def has_tf(self) -> bool:
return 'tensorflow' in sys.modules
@property
def has_torch(self) -> bool:
return 'torch' in sys.modules
@property
def jax(self):
import jax # pylint: disable=g-import-not-at-top # pytype: disable=import-error
return jax
@property
def jnp(self):
import jax.numpy as jnp # pylint: disable=g-import-not-at-top # pytype: disable=import-error
return jnp
@property
def tf(self):
import tensorflow # pylint: disable=g-import-not-at-top # pytype: disable=import-error
return tensorflow
@property
def tnp(self):
import tensorflow.experimental.numpy as tnp # pylint: disable=g-import-not-at-top # pytype: disable=import-error
return tnp
@property
def torch(self):
import torch # pylint: disable=g-import-not-at-top # pytype: disable=import-error
return torch
@property
def np(self):
return np
def is_np_xnp(self, xnp: NpModule) -> bool:
return xnp is _np
def is_tf_xnp(self, xnp: NpModule) -> bool:
return self.has_tf and xnp is self.tnp
def is_jax_xnp(self, xnp: NpModule) -> bool:
return self.has_jax and xnp is self.jnp
def is_torch_xnp(self, xnp: NpModule) -> bool:
return self.has_torch and xnp is self.torch
def is_np(self, x: Array) -> bool:
return isinstance(x, (np.ndarray, np.generic))
def is_tf(self, x: Array) -> bool:
return self.has_tf and isinstance(
x,
(
self.tnp.ndarray,
self.tf.TensorSpec,
self.tf.__internal__.types.Tensor,
),
)
def is_jax(self, x: Array) -> bool:
return self.has_jax and isinstance(x, self.jnp.ndarray)
def is_torch(self, x: Array) -> bool:
return self.has_torch and isinstance(x, self.torch.Tensor)
def is_array(self, x: Array, *, strict: bool = True) -> bool:
is_array_like = False if strict else isinstance(x, _ARRAY_LIKE_TYPES)
return (
self.is_np(x)
or self.is_jax(x)
or self.is_tf(x)
or self.is_torch(x)
or is_array_like
)
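  # Usage sketch (illustrative, not part of the original source): with the
  # default `strict=True`, only actual array objects are accepted, while
  # `strict=False` also accepts the array-like types in `_ARRAY_LIKE_TYPES`:
  #   lazy.is_array(np.zeros((2,)))            # True
  #   lazy.is_array([0, 1, 2])                 # False (strict by default)
  #   lazy.is_array([0, 1, 2], strict=False)   # True (list is array-like)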
def is_np_dtype(self, dtype) -> bool:
return isinstance(dtype, np.dtype) or epy.issubclass(dtype, np.generic)
def is_tf_dtype(self, dtype) -> bool:
return self.has_tf and isinstance(dtype, self.tf.dtypes.DType)
def is_jax_dtype(self, dtype) -> bool:
    # `jnp.int64`,... are `jax._src.numpy.lax_numpy._ScalarMeta`, but
    # `jnp.ndarray.dtype` is a numpy dtype.
check_jax = self.has_jax and isinstance(dtype, type(self.jnp.float32))
return self.is_np_dtype(dtype) or check_jax
def is_torch_dtype(self, dtype) -> bool:
return self.has_torch and isinstance(dtype, self.torch.dtype)
def is_dtype(self, dtype) -> bool:
return (
self.is_np_dtype(dtype)
or self.is_jax_dtype(dtype)
or self.is_tf_dtype(dtype)
or self.is_torch_dtype(dtype)
)
def as_np_dtype(self, dtype):
if self.is_tf_dtype(dtype):
dtype = dtype.as_numpy_dtype
elif self.is_torch_dtype(dtype):
from etils.enp import compat # pylint: disable=g-import-not-at-top
dtype = compat.dtype_torch_to_np(dtype)
elif not self.is_jax_dtype(dtype) and not self.is_np_dtype(dtype):
raise TypeError(f'Invalid dtype: {dtype!r}')
return np.dtype(dtype)
def as_tf_dtype(self, dtype):
return self.tf.dtypes.as_dtype(self.as_np_dtype(dtype))
def as_jax_dtype(self, dtype):
return self.as_np_dtype(dtype) # Jax and numpy types are mostly similar
def as_torch_dtype(self, dtype):
from etils.enp import compat # pylint: disable=g-import-not-at-top
return compat.dtype_np_to_torch(self.as_np_dtype(dtype))
  def as_dtype(self, dtype, *, xnp: NpModule = _np):
    """Normalizes `dtype` to the convention of the given `xnp`."""
if self.is_np_xnp(xnp):
return self.as_np_dtype(dtype)
elif self.is_tf_xnp(xnp):
return self.as_tf_dtype(dtype)
elif self.is_jax_xnp(xnp):
return self.as_jax_dtype(dtype)
elif self.is_torch_xnp(xnp):
return self.as_torch_dtype(dtype)
else:
raise TypeError(f'Unknown xnp: {xnp!r}')
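  # Usage sketch (illustrative): `as_dtype` converts any supported dtype to
  # the convention of the target `xnp` module, e.g. (assuming TF is imported):
  #   lazy.as_dtype(np.float32)                # -> np.dtype('float32')
  #   lazy.as_dtype(np.float32, xnp=lazy.tnp)  # -> tf.float32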
def dtype_from_array(
self,
array_like: Array,
*,
strict: bool = True,
) -> Optional[_np.dtype]:
"""Returns the dtype associated with the array."""
if self.is_array(array_like): # Already an ndarray, normalize the dtype
dtype = array_like.dtype
elif strict: # Not an array and strict mode: error
raise TypeError(
f'Cannot extract dtype from non-array {type(array_like)}, '
'when strict=True.'
)
elif isinstance(array_like, bool):
dtype = np.bool_
elif isinstance(array_like, _ARRAY_LIKE_TYPES): # list, tuple, int, float
      # TODO(epot): Could have a smarter way of inferring the dtype for
      # scalars (int, float,...), but it is difficult to infer a list's dtype
      # without a performance cost (one way would be to call
      # `asarray(array_like, dtype=None)`, then cast again).
return None
else:
raise TypeError(f'Cannot extract dtype from non-array {type(array_like)}')
return self.as_dtype(dtype)
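  # Illustrative examples (sketch, not from the original source):
  #   lazy.dtype_from_array(np.zeros(3, np.int32))     # -> np.dtype('int32')
  #   lazy.dtype_from_array([1, 2, 3], strict=False)   # -> None (not inferred)
  #   lazy.dtype_from_array(True, strict=False)        # -> np.dtype('bool')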
  def get_xnp(self, x: Array, *, strict: bool = True):  # -> NpModule:
    """Returns the numpy module associated with the given array.

    Args:
      x: Either a tf, jax, torch, or numpy array.
      strict: If `False`, default to `np` if the array type cannot be inferred
        (to support array-likes: list, tuple,...).

    Returns:
      The numpy module.
    """
    # This is inspired by NEP 37, but without the `__array_module__` magic:
    # https://numpy.org/neps/nep-0037-array-module.html
    # Note there is also an implementation of NEP 37 from the author, but it
    # looks overly complicated and is not available at Google.
# https://github.com/seberg/numpy-dispatch
if self.is_jax(x):
return self.jnp
elif self.is_tf(x):
return self.tnp
elif self.is_np(x):
return np
elif self.is_torch(x):
return self.torch
elif not strict and isinstance(x, _ARRAY_LIKE_TYPES):
# `strict=False` support `[0, 0, 0]`, `0`,...
return np
else:
raise TypeError(
f'Cannot infer the numpy module from array: {type(x).__name__}'
)
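  # Usage sketch (illustrative): `get_xnp` dispatches to the module matching
  # the array type, without importing jax/tf/torch unless they are already
  # loaded:
  #   lazy.get_xnp(np.zeros((2,)))        # -> np
  #   lazy.get_xnp([0, 1], strict=False)  # -> np (array-like fallback)
  #   lazy.get_xnp(jnp.zeros((2,)))       # -> jax.numpy (if jax is imported)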
@property
  def is_tnp_enabled(self) -> bool:
    """Returns `True` if TF numpy mode (`tnp` behavior) is enabled."""
return self.has_tf and hasattr(self.tf.Tensor, 'reshape')
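  # Note (assumption, not from the original source): calling
  # `tf.experimental.numpy.experimental_enable_numpy_behavior()` adds
  # numpy-style methods such as `.reshape` to `tf.Tensor`, which is what the
  # `hasattr` check above detects.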
class LazyArray(metaclass=_LazyArrayMeta):
  """Represents `tf.Tensor`, `jax.ndarray`, `np.ndarray`, `torch.Tensor`.

  Allows checking `isinstance` without triggering imports of the other modules:

  ```python
  assert isinstance(jnp.zeros((2,)), enp.lazy.LazyArray)
  ```
  """
lazy = _LazyImporter()
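# Usage sketch (illustrative, not part of the original source): `lazy` defers
# the heavy imports, so `lazy.has_jax` only reports whether jax was already
# imported elsewhere, while accessing `lazy.jnp` triggers the actual import:
#   from etils import enp
#   enp.lazy.has_jax    # False unless jax is already in sys.modules
#   jnp = enp.lazy.jnp  # imports jax.numpy on first access
#   enp.lazy.has_jax    # now True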
def get_np_module(array: Array, *, strict: bool = True):  # -> NpModule:
  """Returns the numpy module associated with the given array.

  Args:
    array: Either a tf, jax, torch, or numpy array.
    strict: If `False`, default to `np` if the array type cannot be inferred
      (to support array-likes: list, tuple,...).

  Returns:
    The numpy module.
  """
return lazy.get_xnp(array, strict=strict)
def is_dtype_str(dtype) -> bool:
"""Returns True if the dtype is `str`."""
# tf.string.as_numpy_dtype is object
try:
dtype = np.dtype(dtype)
  except TypeError:  # `jax.random.PRNGKeyArray` fails.
return False
return dtype.type in {np.object_, np.str_, np.bytes_}
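# Illustrative examples (sketch):
#   is_dtype_str(np.str_)          # True
#   is_dtype_str(np.dtype('<U3'))  # True (unicode string dtype)
#   is_dtype_str(np.int32)         # False
#   # `tf.string.as_numpy_dtype` is `object`, so it also returns True
#   # (assuming TF is imported).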
def is_array_str(x: Any) -> bool:
  """Returns `True` if the given array is a `str` array.

  Note: Also returns `True` for scalar `str` and `bytes` values, for
  compatibility with `tensor.numpy()`, which returns `bytes`.

  Args:
    x: The array to test.

  Returns:
    `True` or `False`.
  """
# `Tensor(shape=(), dtype=tf.string).numpy()` returns `bytes`.
if isinstance(x, (bytes, str)):
return True
elif is_array(x):
return is_dtype_str(x.dtype)
else:
return False
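# Illustrative examples (sketch):
#   is_array_str('abc')                 # True (scalar str)
#   is_array_str(b'abc')                # True (scalar bytes)
#   is_array_str(np.array(['a', 'b']))  # True (str dtype)
#   is_array_str(np.array([1, 2]))      # False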
def is_array(x: Any) -> bool:
  """Returns `True` if `x` is a np or `jnp` array."""
if isinstance(x, np.ndarray):
return True
elif lazy.has_jax and isinstance(x, lazy.jnp.ndarray):
return True
else:
return False
@np.vectorize
def _to_str_array(x):
"""Decodes bytes -> str array."""
  # tf.string tensors are returned as bytes, so we need to convert them back
  # to str.
return x.decode('utf8') if isinstance(x, bytes) else x
@typing.overload
def normalize_bytes2str(x: bytes) -> str:
...
@typing.overload
def normalize_bytes2str(x: _T) -> _T:
...
# Ideally could also add `BytesArray -> StrArray`, but both `bytes` and `str`
# are `StrArray`
def normalize_bytes2str(x):
  """Normalizes `bytes` arrays to `str` (UTF-8).

  Example of usage:

  ```python
  for ex in tfds.as_numpy(ds):  # tf.data returns `tf.string` as `bytes`
    ex = tf.nest.map_structure(enp.normalize_bytes2str, ex)
  ```

  Args:
    x: Any array.

  Returns:
    x: `bytes` arrays are decoded as `str`; other values are returned as-is.
  """
if isinstance(x, str):
return x
if isinstance(x, bytes):
return x.decode('utf8')
elif is_array_str(x):
    # Note: `np.char.decode` is likely faster, but doesn't work on `object`
    # nor `bytes` arrays.
return _to_str_array(x)
else:
return x
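# Illustrative examples (sketch, not from the original source): `bytes`
# scalars and arrays are decoded, everything else passes through unchanged:
#   normalize_bytes2str(b'hello')                # -> 'hello'
#   normalize_bytes2str(np.array([b'a', b'b']))  # -> array(['a', 'b'], dtype='<U1')
#   normalize_bytes2str(np.array([1, 2]))        # -> array([1, 2]) (unchanged)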