Spaces:
Building
Building
# Copyright 2024 The etils Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Numpy utils. | |
Attributes: | |
tau: The circle constant (2 * pi). (https://tauday.com/) | |
""" | |
from __future__ import annotations | |
import sys | |
import typing | |
from typing import Any, Optional, TypeVar | |
from etils import epy | |
import numpy as np | |
if typing.TYPE_CHECKING: | |
from etils.enp.typing import Array | |
# Generic type variable used by the `normalize_bytes2str` overloads.
_T = TypeVar('_T')

# TODO(pytype): Ideally should use `-> Literal[np]:` but Python does not
# support this: https://github.com/python/typing/issues/1039
# Thankfully, pytype correctly auto-infers `np` when returned by `get_xnp`.
NpModule = Any

# Mirror math.tau (PEP 628). See https://tauday.com/
tau = 2 * np.pi

# When `strict=False` (in `get_xnp`, `is_array`,...), those types are also
# accepted as array-like.
_ARRAY_LIKE_TYPES = (int, bool, float, list, tuple)

# During the class construction, pytype fails because of a name conflict
# between the `np` `@property` and the module, so keep a second alias of the
# numpy module that the class body can reference safely.
_np = np
class _LazyArrayMeta(type):
  """Metaclass making `isinstance(x, LazyArray)` delegate to `lazy.is_array`."""

  def __instancecheck__(cls, obj) -> bool:
    # `lazy` is resolved at call time: the singleton is created after the
    # classes are defined.
    return lazy.is_array(obj)


class _LazyImporter:
  """Lazy import module.

  Help to write code seamlessly working with np, Jax, TF and torch.

  Because libs are lazily imported, TF, Jax and torch are always optional
  dependencies.

  The `has_xxx` flags and the module accessors (`jax`, `jnp`, `tf`, `tnp`,
  `torch`, `np`) are properties: the rest of the class (and the module) reads
  them as plain attributes, e.g. `self.has_tf and xnp is self.tnp`.
  """

  @property
  def has_jax(self) -> bool:
    """`True` if `jax` was already imported by the program."""
    return 'jax' in sys.modules

  @property
  def has_tf(self) -> bool:
    """`True` if `tensorflow` was already imported by the program."""
    return 'tensorflow' in sys.modules

  @property
  def has_torch(self) -> bool:
    """`True` if `torch` was already imported by the program."""
    return 'torch' in sys.modules

  @property
  def jax(self):
    """The `jax` module (imported on first access)."""
    import jax  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return jax

  @property
  def jnp(self):
    """The `jax.numpy` module (imported on first access)."""
    import jax.numpy as jnp  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return jnp

  @property
  def tf(self):
    """The `tensorflow` module (imported on first access)."""
    import tensorflow  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return tensorflow

  @property
  def tnp(self):
    """The `tensorflow.experimental.numpy` module (imported on first access)."""
    import tensorflow.experimental.numpy as tnp  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return tnp

  @property
  def torch(self):
    """The `torch` module (imported on first access)."""
    import torch  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return torch

  @property
  def np(self):
    """The `numpy` module."""
    # Inside the method body, `np` resolves to the module-level import (class
    # attributes are not visible from method scopes).
    return np

  def is_np_xnp(self, xnp: NpModule) -> bool:
    """Returns `True` if `xnp` is the `numpy` module."""
    return xnp is _np

  def is_tf_xnp(self, xnp: NpModule) -> bool:
    """Returns `True` if `xnp` is `tensorflow.experimental.numpy`."""
    # `has_tf` is checked first so TF is never imported only for this test.
    return self.has_tf and xnp is self.tnp

  def is_jax_xnp(self, xnp: NpModule) -> bool:
    """Returns `True` if `xnp` is `jax.numpy`."""
    return self.has_jax and xnp is self.jnp

  def is_torch_xnp(self, xnp: NpModule) -> bool:
    """Returns `True` if `xnp` is the `torch` module."""
    return self.has_torch and xnp is self.torch

  def is_np(self, x: Array) -> bool:
    """Returns `True` if `x` is a `np.ndarray` or a numpy scalar."""
    return isinstance(x, (np.ndarray, np.generic))

  def is_tf(self, x: Array) -> bool:
    """Returns `True` if `x` is a TF tensor, `tnp.ndarray` or `TensorSpec`."""
    return self.has_tf and isinstance(
        x,
        (
            self.tnp.ndarray,
            self.tf.TensorSpec,
            self.tf.__internal__.types.Tensor,
        ),
    )

  def is_jax(self, x: Array) -> bool:
    """Returns `True` if `x` is a `jnp.ndarray`."""
    return self.has_jax and isinstance(x, self.jnp.ndarray)

  def is_torch(self, x: Array) -> bool:
    """Returns `True` if `x` is a `torch.Tensor`."""
    return self.has_torch and isinstance(x, self.torch.Tensor)

  def is_array(self, x: Array, *, strict: bool = True) -> bool:
    """Returns `True` if `x` is a np, jax, TF or torch array.

    Args:
      x: The object to test.
      strict: If `False`, array-like Python values (`int`, `bool`, `float`,
        `list`, `tuple`) are accepted too.

    Returns:
      Whether `x` is an array (or array-like when `strict=False`).
    """
    is_array_like = False if strict else isinstance(x, _ARRAY_LIKE_TYPES)
    return (
        self.is_np(x)
        or self.is_jax(x)
        or self.is_tf(x)
        or self.is_torch(x)
        or is_array_like
    )

  def is_np_dtype(self, dtype) -> bool:
    """Returns `True` for `np.dtype` instances and `np.generic` subclasses."""
    return isinstance(dtype, np.dtype) or epy.issubclass(dtype, np.generic)

  def is_tf_dtype(self, dtype) -> bool:
    """Returns `True` if `dtype` is a `tf.dtypes.DType`."""
    return self.has_tf and isinstance(dtype, self.tf.dtypes.DType)

  def is_jax_dtype(self, dtype) -> bool:
    """Returns `True` if `dtype` is a numpy dtype or a `jnp` scalar type."""
    # `jnp.int64`,... are `jax._src.numpy.lax_numpy._ScalarMeta`, but
    # jnp.ndarray.dtype are numpy dtype
    check_jax = self.has_jax and isinstance(dtype, type(self.jnp.float32))
    return self.is_np_dtype(dtype) or check_jax

  def is_torch_dtype(self, dtype) -> bool:
    """Returns `True` if `dtype` is a `torch.dtype`."""
    return self.has_torch and isinstance(dtype, self.torch.dtype)

  def is_dtype(self, dtype) -> bool:
    """Returns `True` if `dtype` is a np, jax, TF or torch dtype."""
    return (
        self.is_np_dtype(dtype)
        or self.is_jax_dtype(dtype)
        or self.is_tf_dtype(dtype)
        or self.is_torch_dtype(dtype)
    )

  def as_np_dtype(self, dtype):
    """Normalizes a np, jax, TF or torch dtype to a `np.dtype`.

    Args:
      dtype: The dtype to normalize.

    Returns:
      The matching `np.dtype`.

    Raises:
      TypeError: If `dtype` is not recognized as a dtype of any backend.
    """
    if self.is_tf_dtype(dtype):
      dtype = dtype.as_numpy_dtype
    elif self.is_torch_dtype(dtype):
      from etils.enp import compat  # pylint: disable=g-import-not-at-top

      dtype = compat.dtype_torch_to_np(dtype)
    elif not self.is_jax_dtype(dtype) and not self.is_np_dtype(dtype):
      raise TypeError(f'Invalid dtype: {dtype!r}')
    return np.dtype(dtype)

  def as_tf_dtype(self, dtype):
    """Normalizes `dtype` to a `tf.dtypes.DType`."""
    return self.tf.dtypes.as_dtype(self.as_np_dtype(dtype))

  def as_jax_dtype(self, dtype):
    """Normalizes `dtype` for Jax (returned as a `np.dtype`)."""
    return self.as_np_dtype(dtype)  # Jax and numpy types are mostly similar

  def as_torch_dtype(self, dtype):
    """Normalizes `dtype` to a `torch.dtype`."""
    from etils.enp import compat  # pylint: disable=g-import-not-at-top

    return compat.dtype_np_to_torch(self.as_np_dtype(dtype))

  def as_dtype(self, dtype, *, xnp: NpModule = _np):
    """Normalize to dtype for the given `xnp`.

    Args:
      dtype: A np, jax, TF or torch dtype.
      xnp: The numpy-like module to target (defaults to `numpy`).

    Returns:
      The dtype expressed in the `xnp` backend.

    Raises:
      TypeError: If `xnp` is not one of the supported modules, or if `dtype`
        is not a valid dtype.
    """
    if self.is_np_xnp(xnp):
      return self.as_np_dtype(dtype)
    elif self.is_tf_xnp(xnp):
      return self.as_tf_dtype(dtype)
    elif self.is_jax_xnp(xnp):
      return self.as_jax_dtype(dtype)
    elif self.is_torch_xnp(xnp):
      return self.as_torch_dtype(dtype)
    else:
      raise TypeError(f'Unknown xnp: {xnp!r}')

  def dtype_from_array(
      self,
      array_like: Array,
      *,
      strict: bool = True,
  ) -> Optional[_np.dtype]:
    """Returns the dtype associated with the array.

    Args:
      array_like: The array (or, when `strict=False`, array-like value).
      strict: If `True`, only actual arrays are accepted.

    Returns:
      The dtype normalized to numpy, or `None` when it cannot be inferred
      cheaply (`int`, `float`, `list`, `tuple` with `strict=False`).

    Raises:
      TypeError: If the dtype cannot be extracted from `array_like`.
    """
    if self.is_array(array_like):  # Already an ndarray, normalize the dtype
      dtype = array_like.dtype
    elif strict:  # Not an array and strict mode: error
      raise TypeError(
          f'Cannot extract dtype from non-array {type(array_like)}, '
          'when strict=True.'
      )
    elif isinstance(array_like, bool):
      dtype = np.bool_
    elif isinstance(array_like, _ARRAY_LIKE_TYPES):  # list, tuple, int, float
      # TODO(epot): Could have a smarter way of inferring the dtype for
      # scalar, int, float,... but difficult to infer list without performance
      # cost (one way would be to call `asarray(array_like, dtype=None)`, then
      # cast again)
      return None
    else:
      raise TypeError(f'Cannot extract dtype from non-array {type(array_like)}')
    return self.as_dtype(dtype)

  def get_xnp(self, x: Array, *, strict: bool = True):  # -> NpModule:
    """Returns the numpy module associated with the given array.

    Args:
      x: Either tf, jax, torch or numpy array.
      strict: If `False`, default to `np.array` if the array can't be inferred
        (to support array-like: list, tuple,...)

    Returns:
      The numpy module.

    Raises:
      TypeError: If no module can be inferred from `x`.
    """
    # This is inspired from NEP 37 but without the `__array_module__` magic:
    # https://numpy.org/neps/nep-0037-array-module.html
    # Note there is also an implementation of NEP 37 from the author, but look
    # overly complicated and not available at google.
    # https://github.com/seberg/numpy-dispatch
    if self.is_jax(x):
      return self.jnp
    elif self.is_tf(x):
      return self.tnp
    elif self.is_np(x):
      return np
    elif self.is_torch(x):
      return self.torch
    elif not strict and isinstance(x, _ARRAY_LIKE_TYPES):
      # `strict=False` support `[0, 0, 0]`, `0`,...
      return np
    else:
      raise TypeError(
          f'Cannot infer the numpy module from array: {type(x).__name__}'
      )

  @property
  def is_tnp_enabled(self) -> bool:
    """Returns `True` if numpy mode is enabled."""
    # Detected by checking that `tf.Tensor` exposes a `reshape` attribute.
    return self.has_tf and hasattr(self.tf.Tensor, 'reshape')

  class LazyArray(metaclass=_LazyArrayMeta):
    """Represent `tf.Tensor`, `jax.ndarray`, `np.ndarray`, `torch.Tensor`.

    Allow to check isinstance without triggering imports from other modules:

    ```
    assert isinstance(jnp.zeros((2,)), enp.lazy.LazyArray)
    ```
    """


# Module-level singleton through which the lazy helpers are accessed.
lazy = _LazyImporter()
def get_np_module(array: Array, *, strict: bool = True):  # -> NpModule:
  """Looks up the numpy-like module matching `array`.

  Thin module-level wrapper around `lazy.get_xnp`.

  Args:
    array: Either tf, jax or numpy array.
    strict: If `False`, default to `np.array` if the array can't be inferred
      (to support array-like: list, tuple,...)

  Returns:
    The numpy module.
  """
  xnp = lazy.get_xnp(array, strict=strict)
  return xnp
def is_dtype_str(dtype) -> bool:
  """Tells whether `dtype` is a string-like dtype (`object`, `str`, `bytes`).

  Args:
    dtype: Anything `np.dtype` may accept.

  Returns:
    `True` if the normalized dtype is `np.object_`, `np.str_` or `np.bytes_`.
    `False` otherwise, including when `np.dtype` rejects the value with a
    `TypeError`.
  """
  # `object` is included because tf.string.as_numpy_dtype is object
  str_like_types = (np.object_, np.str_, np.bytes_)
  try:
    normalized = np.dtype(dtype)
  except TypeError:  # `jax.random.PRNGKeyArray` fail.
    return False
  else:
    return any(normalized.type is candidate for candidate in str_like_types)
def is_array_str(x: Any) -> bool:
  """Tells whether `x` is a `str` array (or a scalar `str` / `bytes`).

  Note: Scalar `str` and `bytes` values are accepted too, for compatibility
  with `tensor.numpy()` which returns `bytes`.

  Args:
    x: The array to test

  Returns:
    True or False
  """
  # `Tensor(shape=(), dtype=tf.string).numpy()` returns `bytes`.
  if isinstance(x, (str, bytes)):
    return True
  # Anything that is not an array cannot be a `str` array.
  return is_dtype_str(x.dtype) if is_array(x) else False
def is_array(x: Any) -> bool:
  """Returns `True` if `x` is a `np.ndarray` or a `jnp.ndarray`."""
  if isinstance(x, np.ndarray):
    return True
  # Jax is only consulted when the program already imported it.
  return bool(lazy.has_jax and isinstance(x, lazy.jnp.ndarray))
def _decode_if_bytes(item):
  """Decodes a single `bytes` item as UTF-8, returns anything else as-is."""
  return item.decode('utf8') if isinstance(item, bytes) else item


def _to_str_array(x):
  """Decodes bytes -> str array.

  Args:
    x: A `np.ndarray` (typically of `bytes` / `object` dtype) or a scalar.

  Returns:
    * For a scalar `bytes`: the decoded `str`.
    * For a non-empty `np.ndarray`: a new array where every `bytes` element was
      decoded as UTF-8 (other elements are kept). The output dtype is inferred
      by `np.vectorize` from the first decoded element.
    * Anything else (including empty arrays): `x` unchanged.
  """
  # tf.string tensors are returned as bytes, so need to convert them back to str
  if isinstance(x, bytes):
    return x.decode('utf8')
  if not isinstance(x, np.ndarray):
    return x
  if x.size == 0:
    # Nothing to decode; also `np.vectorize` rejects size-0 inputs when the
    # output type has to be inferred.
    return x
  # Decode element-wise: testing `isinstance(x, bytes)` on the array itself is
  # always False, which would leave `bytes` arrays undecoded.
  return np.vectorize(_decode_if_bytes)(x)
@typing.overload
def normalize_bytes2str(x: bytes) -> str:
  ...


@typing.overload
def normalize_bytes2str(x: _T) -> _T:
  ...


# Ideally could also add `BytesArray -> StrArray`, but both `bytes` and `str`
# are `StrArray`
def normalize_bytes2str(x):
  """Normalize `bytes` array to `str` (UTF-8).

  Example of usage:

  ```python
  for ex in tfds.as_numpy(ds):  # tf.data returns `tf.string` as `bytes`
    ex = tf.nest.map_structure(enp.normalize_bytes2str, ex)
  ```

  Args:
    x: Any array

  Returns:
    x: `bytes` array are decoded as `str`. Scalar `bytes` are decoded too;
      `str` and any non-string value are returned unchanged.
  """
  if isinstance(x, str):
    return x
  if isinstance(x, bytes):
    return x.decode('utf8')
  if is_array_str(x):
    # Note: `np.char.decode` is likely faster but don't work on `object` nor
    # bytes arrays.
    return _to_str_array(x)
  return x