PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check util."""
from __future__ import annotations
import collections
import dataclasses
import functools
import inspect
import typing
from typing import Any, Callable, Optional, TypeVar
from etils import epy
from etils.enp import numpy_utils
from etils.enp import type_parsing
from etils.enp.array_types import typing as array_typing
import numpy as np
_TypeForm = Any
_Fn = TypeVar('_Fn')
# TODO(epot): Support:
# * return annotations
# * Tuple,...
# * Union
# * shape checking
@dataclasses.dataclass
class _ArrayParam:
  """Description of a single function argument annotated as an array.

  Attributes:
    type: The array annotation of the argument. Its `dtype` attribute is used
      to perform the conversion in `asarray`.
    is_optional: Whether the annotation also accepts `None`.
    name: Name of the argument, used to prefix error messages.
  """

  type: array_typing.ArrayAliasMeta
  is_optional: bool
  name: str

  def asarray(self, v, *, xnp: numpy_utils.NpModule):
    """Normalizes `v` to an `xnp` array with the dtype of the annotation.

    Args:
      v: The value given for this argument.
      xnp: The numpy-like module the value should be converted to.

    Returns:
      Whatever `self.type.dtype.asarray` returns for `v`.

    Raises:
      Exception: Any conversion failure, re-raised through `epy.reraise` with
        the message prefixed by `Invalid <name>: `.
    """
    try:
      # The attribute lookups stay inside the `try` so that any failure gets
      # the same `Invalid <name>: ` prefix as a conversion error.
      expected_dtype = self.type.dtype
      converted = expected_dtype.asarray(v, xnp=xnp, casting='none')
    except Exception as e:  # pylint: disable=broad-except
      epy.reraise(e, prefix=f'Invalid {self.name}: ')
    else:
      return converted
@dataclasses.dataclass
class _FnSignatureCache:
  """Cache of the function signature.

  Built once per decorated function by `_parse_signature` and then reused on
  every call of the decorated function.
  """

  # Result of `inspect.signature(fn)`, used to bind the call arguments.
  sig: inspect.Signature
  # Whether the decorated function itself declares a `xnp` parameter.
  has_xnp_kwargs: bool
  # Array-annotated parameters, keyed by parameter name.
  array_params: dict[str, _ArrayParam]
# Typing overload for the parenthesized form
# `@check_and_normalize_arrays(strict=...)`: no function is given, so a
# decorator (function -> function) is returned.
@typing.overload
def check_and_normalize_arrays(
    fn: None = ...,
    *,
    strict: bool = ...,
) -> Callable[[_Fn], _Fn]:
  ...


# Typing overload for the direct form `@check_and_normalize_arrays`: the
# function is given and a function with the same type is returned.
@typing.overload
def check_and_normalize_arrays(
    fn: _Fn = ...,
    *,
    strict: bool = ...,
) -> _Fn:
  ...
def check_and_normalize_arrays(fn=None, *, strict: bool = True):
  """Decorator which checks and normalizes the array arguments of `fn`.

  On every call, the decorated function:

  * Converts each array-annotated argument (skipping `None` values) to the
    dtype declared by its annotation.
  * Normalizes np, jnp, tf,... inputs to a single `xnp` module, either the one
    explicitly passed as `xnp=` or the one inferred from the arguments.
  * Accepts an optional `xnp` keyword argument, forwarded to `fn` only when
    `fn` declares a `xnp` parameter.

  Shape checking is not implemented yet (see the TODO in the body).

  See doc at: https://github.com/google/etils/blob/main/etils/array_types/README.md

  Example:

  ```python
  @enp.check_and_normalize_arrays(strict=False)
  def add(x: FloatArray[...], y: FloatArray[...]) -> FloatArray[...]:
    return x + y

  # Inside the function, `np` normalized to `jnp`
  add(np.array(1.), jnp.array(2.)) == jnp.array(3.)

  # strict=False, so `list` accepted and normalized to `xnp`
  add(jnp.array(1.), [1., 2., 3.]) == jnp.array([2., 3., 4.])
  ```

  Args:
    fn: The function to decorate. The array arguments are inferred from its
      type annotations, the first time the decorated function is called.
    strict: If `False`, `fn` will also accept list, int,... in which case those
      are automatically converted to `xnp`.

  Returns:
    The decorated function. When `fn` is `None` (parenthesized usage), a
    decorator waiting for the function instead.
  """
  # Parenthesized usage: `@check_and_normalize_arrays(strict=...)`.
  if fn is None:
    return functools.partial(check_and_normalize_arrays, strict=strict)

  # The parsed signature is stored on `fn` itself and filled lazily.
  fn._array_types_state = None  # pylint: disable=protected-access

  @functools.wraps(fn)
  def decorated_fn(*args, **kwargs):
    try:
      call_kwargs = dict(kwargs)
      explicit_xnp = call_kwargs.pop('xnp', None)

      # First time the function is called, precompute & cache the info
      cache: _FnSignatureCache = fn._array_types_state  # pylint: disable=protected-access
      if cache is None:
        cache = _parse_signature(fn)
        fn._array_types_state = cache  # pylint: disable=protected-access

      # Placeholder so that binding succeeds even when `xnp` has no default
      # value in the signature. The real value is set after normalization.
      if cache.has_xnp_kwargs:
        call_kwargs['xnp'] = ...

      bound = cache.sig.bind(*args, **call_kwargs)

      # Keep only the array arguments which were actually given
      # TODO(epot): Should raise an error for non-optional when v is None
      given_arrays = {}
      for arg_name, value in bound.arguments.items():
        if value is None:
          continue
        if arg_name in cache.array_params:
          given_arrays[arg_name] = value

      # Explicit `xnp` wins, otherwise infer it from the arguments
      if explicit_xnp:
        xnp = explicit_xnp
      else:
        xnp = _get_xnp(given_arrays, strict=strict)
      _maybe_set_tnp_casting(xnp)

      # Convert every array to `xnp` (this also checks the dtype), and write
      # the normalized value back into the bound arguments.
      # TODO(epot): Check the shape
      for arg_name, value in given_arrays.items():
        spec = cache.array_params[arg_name]
        bound.arguments[arg_name] = spec.asarray(value, xnp=xnp)

      # Forward the resolved `xnp` only if `fn` expects it
      if cache.has_xnp_kwargs:
        bound.arguments['xnp'] = xnp
    except Exception as e:  # pylint: disable=broad-except
      epy.reraise(
          e,
          prefix=(
              f'@enp.check_and_normalize_arrays error for {fn.__qualname__}: '
          ),
      )

    # Called outside the `try` so errors from `fn` itself are not prefixed.
    return fn(*bound.args, **bound.kwargs)

  return decorated_fn
def _get_xnp(
    array_args: dict[str, Any],
    *,
    strict: bool,
) -> numpy_utils.NpModule:
  """Finds the single `xnp` module shared by all the array arguments.

  Args:
    array_args: Mapping argument name -> value, for the array arguments.
    strict: Forwarded to `numpy_utils.lazy.get_xnp`.

  Returns:
    The module selected by `_infer_xnp` from the modules of all arguments.
  """
  # Group the argument names by the module their value belongs to.
  names_per_xnp = collections.defaultdict(list)
  for arg_name, value in array_args.items():
    try:
      module = numpy_utils.lazy.get_xnp(value, strict=strict)
      names_per_xnp[module].append(arg_name)
    except Exception as e:  # pylint: disable=broad-except
      epy.reraise(e, prefix=f'Invalid {arg_name}: Expected xnp.ndarray: ')
  return _infer_xnp(names_per_xnp)
def _infer_xnp(
    xnps: dict[numpy_utils.NpModule, list[str]]
) -> numpy_utils.NpModule:
  """Selects the `xnp` module to use among the ones found in the arguments.

  Args:
    xnps: Mapping module -> names of the arguments which use this module.

  Returns:
    `np` when no other module is present, otherwise the single non-`np`
    module.

  Raises:
    ValueError: If more than one non-`np` module is present.
  """
  # jnp, tnp, torch take precedence on `np`
  candidates = {module for module in xnps if module != np}

  if not candidates:
    return np
  if len(candidates) == 1:
    return next(iter(candidates))

  # Several non-`np` modules: report which arguments use which module
  names_per_module = {
      module.__name__: arg_names for module, arg_names in xnps.items()
  }
  raise ValueError(f'Conflicting numpy types: {names_per_module}')
# Cached on `xnp`: the body (including the warning) runs at most once per
# module passed in.
@functools.lru_cache(None)
def _maybe_set_tnp_casting(xnp: numpy_utils.NpModule) -> None:
  """If TF numpy mode is not set, make sure `tnp.asarray(1.)` is `tf.float32`.

  If user uses TF without numpy mode, it will create casting issues (for
  example: `tf.float64 + tf.float32` will raise an error).
  To limit the errors encountered, we set `tnp.asarray(1.)` to `tf.float32`
  instead of `tf.float64`.

  If numpy mode is already activated, then no need to do anything, as
  `tf.float64 + tf.float32` will support auto-casting, like Jax and Numpy.

  Side effects: may switch the `prefer_float32` setting of TF numpy on, and
  print a warning to stdout.

  Args:
    xnp: numpy module.
  """
  # `has_tf` is tested first, so `lazy.tnp` is only accessed when it is true.
  if not numpy_utils.lazy.has_tf or xnp is not numpy_utils.lazy.tnp:
    return  # Not tnp module

  if not numpy_utils.lazy.is_tnp_enabled:
    # When TF numpy mode is not enabled, `tnp.asarray(1.)` returns tf.float64,
    # creating conflict because TF do fail for operations like:
    # `tf.float64 + tf.float32`
    from tensorflow.python.ops.numpy_ops import np_dtypes  # pylint: disable=g-import-not-at-top,g-direct-tensorflow-import  # pytype: disable=import-error

    # Only change the setting (and warn) when float32 is not already preferred.
    # NOTE(review): the warning is assumed to belong to this branch; the
    # original indentation was not available, confirm against upstream.
    if not np_dtypes.is_prefer_float32():
      np_dtypes.set_prefer_float32(True)
      msg = epy.dedent(
          """
          WARNING: Using array types for TF but without numpy mode enabled. It
          is recommended to activate numpy mode as:
          import tensorflow.experimental.numpy as tnp
          tnp.experimental_enable_numpy_behavior(prefer_float32=True)
          """
      )
      # Use print otherwise this isn't displayed on Colab
      # Could have a `epy.logging` module which auto-print on Colab.
      print(msg)
def _parse_signature(fn) -> _FnSignatureCache:
  """Parses the signature of `fn` into a `_FnSignatureCache`.

  Resolves the type hints of `fn`, inspects its signature, and collects one
  `_ArrayParam` for every parameter that `_get_array_param` recognizes as an
  array.

  Args:
    fn: The function being decorated.

  Returns:
    The cache holding the signature, whether `fn` declares a `xnp` parameter,
    and the array parameters keyed by name.

  Raises:
    ValueError: If no parameter of `fn` has an array type hint.
    Exception: Any error raised while resolving the type hints, re-raised
      through `epy.reraise` with a prefix naming the function and its module.
  """
  # At this point, `ForwardRef` should have been resolved.
  try:
    hints = typing.get_type_hints(fn)
  except Exception as e:  # pylint: disable=broad-except
    epy.reraise(
        e,
        prefix=(
            # Trailing `: ` separates the prefix from the original message,
            # like the other `epy.reraise` prefixes of this module.
            f'Could not infer typing annotation of {fn.__qualname__} '
            f'defined in {fn.__module__}: '
        ),
    )

  sig = inspect.signature(fn)

  # For each valid params, create the validator
  # TODO(py38): Use :=
  array_params = {}
  for name, param in sig.parameters.items():
    array_param = _get_array_param(param, hints)
    if array_param is not None:
      array_params[name] = array_param

  # Decorating a function without any array argument is considered an error.
  if not array_params:
    raise ValueError(
        f'Could not detect any array type hints in {fn.__qualname__} with '
        f'signature {sig}.'
    )

  return _FnSignatureCache(
      sig=sig,
      has_xnp_kwargs='xnp' in sig.parameters,
      array_params=array_params,
  )
def _get_array_param(
    param: inspect.Parameter,
    hints: dict[str, _TypeForm],
) -> Optional[_ArrayParam]:
  """Parses the type hint of one parameter into an `_ArrayParam`.

  Args:
    param: The parameter to inspect.
    hints: Resolved type hints of the function, keyed by parameter name.

  Returns:
    The `_ArrayParam` describing the parameter, or `None` when the parameter
    is not an array parameter (no annotation, or no array type among the
    leaf types of its annotation).

  Raises:
    NotImplementedError: For unsupported annotations: union mixing array and
      non-array types, union of several array types, or an array annotation
      on `*args` / `**kwargs`.
  """
  name = param.name
  if name not in hints:  # Not an array param
    return None
  hint = hints[name]

  def make_err(msg: str) -> Exception:
    return NotImplementedError(
        f'`enp.check_and_normalize_arrays` does not support {msg}. Please open '
        f'an issue if you need this feature. For `{name}: {hint}`'
    )

  leaf_types = type_parsing.get_leaf_types(hint)
  # A `None` leaf marks the annotation as optional
  is_optional = None in leaf_types
  # Filter Optional
  leaf_types = [t for t in leaf_types if t is not None]

  # Currently, only Optional[Array] or Array supported
  are_array = [isinstance(t, array_typing.ArrayAliasMeta) for t in leaf_types]
  count_array = are_array.count(True)
  count_non_array = are_array.count(False)

  if count_array and count_non_array:
    raise make_err('Union of array and non-array')
  if count_array > 1:
    raise make_err('Union of arrays')
  if not count_array:
    # Not an array param. This also covers an annotation without any non-None
    # leaf type (e.g. `x: None`), for which the unpacking below would
    # otherwise fail with an unrelated `ValueError`.
    return None

  # Exactly one array leaf type remains at this point
  (array_type,) = leaf_types

  if param.kind in {
      inspect.Parameter.VAR_POSITIONAL,
      inspect.Parameter.VAR_KEYWORD,
  }:
    raise make_err('*args, **kwargs')

  return _ArrayParam(
      is_optional=is_optional,
      type=array_type,
      name=name,
  )