PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tree API."""
import concurrent.futures
import functools
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar
from etils import enp
from etils import etqdm
from etils.array_types import Array
from etils.etree import backend as backend_lib
from etils.etree.typing import LeafFn, Tree # pylint: disable=g-importing-member,g-multiple-import
# Element type of the iterables stored at the leaves of the tree given to
# `TreeAPI.unzip` (see its `Tree[Iterable[_T]] -> Iterator[Tree[_T]]`
# signature).
_T = Any  # TODO(pytype): Replace by `TypeVar`
# Leaf type of the input trees forwarded to `map_fn` in `TreeAPI.map` and
# `TreeAPI.parallel_map`.
_Tin = Any  # Could make this a `TypeVar` if typing supported variadic generics
# Return type of `map_fn`, i.e. the leaf type of the mapped output tree.
_Tout = TypeVar('_Tout')
class TreeAPI:
  """Tree API, using either `jax.tree_util`, `tf.nest` or `tree` backend.

  Every method delegates the actual traversal to `self.backend` (its `map`,
  `flatten` and `unflatten` methods), so what counts as a container or a leaf
  is decided by the backend.

  Attributes:
    backend: Backend used to traverse, flatten and rebuild the trees.
  """

  def __init__(self, backend: backend_lib.Backend):
    self.backend = backend

  def map(
      self,
      map_fn: Callable[..., _Tout],  # Callable[[_Tin0, _Tin1,...], Tout]
      *trees: Tree[_Tin],  # _Tin0, _Tin1,...
      is_leaf: Optional[LeafFn] = None,
  ) -> Tree[_Tout]:
    """Same as `tree.map_structure`.

    Args:
      map_fn: Worker function, called with one leaf from each of `trees`.
      *trees: Nested input to pass to the `map_fn`
      is_leaf: Don't recurse into leaf if `is_leaf(node)` is `True`

    Returns:
      The nested structure after `map_fn` has been applied.
    """
    return self.backend.map(map_fn, *trees, is_leaf=is_leaf)

  def parallel_map(
      self,
      map_fn: Callable[..., _Tout],  # Callable[[_Tin0, _Tin1,...], Tout]
      *trees: Tree[_Tin],  # _Tin0, _Tin1,...
      num_threads: Optional[int] = None,
      progress_bar: bool = False,
      is_leaf: Optional[LeafFn] = None,
  ) -> Tree[_Tout]:
    """Same as `tree.map_structure` but apply `map_fn` in parallel.

    Each call to `map_fn` is submitted as one task to a
    `concurrent.futures.ThreadPoolExecutor`. The method blocks until all
    tasks have completed (or until the first failure, see `Raises`).

    Args:
      map_fn: Worker function, called with one leaf from each of `trees`.
      *trees: Nested input to pass to the `map_fn`
      num_threads: Number of workers, forwarded as `max_workers` to the
        `ThreadPoolExecutor`. `None` lets the executor pick its own default.
      progress_bar: If True, display a progression bar.
      is_leaf: Don't recurse into leaf if `is_leaf(node)` is `True`

    Returns:
      The nested structure after `map_fn` has been applied.

    Raises:
      BaseException: The first exception observed from a worker is re-raised
        in the calling thread. Tasks which have not started yet are cancelled;
        tasks already running are waited for.
    """
    # TODO(epot): Allow nesting `parallel_map` while keeping max num threads
    # constant. How to avoid dead locks ?
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=num_threads
    ) as executor:
      launch_worker = functools.partial(executor.submit, map_fn)
      futures = self.backend.map(launch_worker, *trees, is_leaf=is_leaf)

      # `is_leaf` was written for the *input* trees. Re-applying it on the tree
      # of futures could classify a container of futures as a single leaf
      # (e.g. a predicate inspecting the container's content), which would then
      # break `as_completed`. A `Future` is never a container, so the default
      # flattening yields exactly the submitted futures. This is also what the
      # final `map` below relies on.
      leaves, _ = self.backend.flatten(futures)

      itr = concurrent.futures.as_completed(leaves)
      if progress_bar:
        itr = etqdm.tqdm(itr, total=len(leaves))

      try:
        for f in itr:  # Propagate exception to main thread.
          error = f.exception()
          if error is not None:
            raise error
      except BaseException:
        # Leaving the `with` block waits for every task still queued. Cancel
        # the ones which have not started so the error is reported promptly
        # (`cancel()` is a no-op on running or finished futures).
        for pending in leaves:
          pending.cancel()
        raise

    return self.backend.map(lambda f: f.result(), futures)

  def unzip(self, tree: Tree[Iterable[_T]]) -> Iterator[Tree[_T]]:
    """Unpack a tree of iterable.

    This is the reverse operation of `tree.map_structure(zip, *trees)`

    Example:

    ```python
    list(etree.unzip({'a': np.array([1, 2, 3])})) == [
        {'a': 1}, {'a': 2}, {'a': 3}
    ]
    ```

    The leaves are iterated together with `zip`, so iteration stops as soon as
    the shortest leaf is exhausted. A tree without any leaf yields nothing.

    Args:
      tree: The tree to unzip

    Yields:
      Trees of same structure as the input, but with individual elements.
    """
    leaves, treedef = self.backend.flatten(tree)
    for leaf_elems in zip(*leaves):  # TODO(py310): check=True
      yield self.backend.unflatten(treedef, leaf_elems)

  def stack(
      self, trees: Iterable[Tree[Array['*s']]]
  ) -> Tree[Array['n_trees *s']]:
    """Stack an iterable of trees of arrays into a single tree of arrays.

    Corresponding leaves of all the trees are stacked together by `_stack`.

    Supports `jax`, `tf`, `np`.

    Example:

    ```python
    etree.stack([
        {'a': np.array([1])},
        {'a': np.array([2])},
        {'a': np.array([3])},
    ]) == {
        'a': np.array([[1], [2], [3]])
    }
    ```

    Args:
      trees: The trees to stack.

    Returns:
      Tree of arrays.
    """
    return self.backend.map(_stack, *trees)

  def spec_like(
      self,
      tree: Tree[Array],
      *,
      ignore_other: bool = True,
  ) -> Tree[enp.ArraySpec]:
    """Inspect a tree of array, works with any array type.

    Example:

    ```python
    model = MyModel()
    variables = model.init(jax.random.PRNGKey(0), x)

    # Inspect the `variables` tree structures
    print(etree.spec_like(variables))
    ```

    Args:
      tree: The tree of array
      ignore_other: If `True`, non-array are forwarded as-is.

    Returns:
      The tree of `enp.ArraySpec`.

    Raises:
      TypeError: If a leaf is not recognized by `enp.ArraySpec.is_array` and
        `ignore_other` is `False`.
    """

    def _to_spec_array(array):
      if enp.ArraySpec.is_array(array):
        return enp.ArraySpec.from_array(array)
      if ignore_other:
        return array
      raise TypeError(f'Unknown array type: {type(array)}')

    return self.backend.map(_to_spec_array, tree)
def _stack(*arrs: Array) -> Array:
  """Stack the given arrays with the `stack` function of their array module.

  The module is looked up from the first array only, through
  `enp.lazy.get_xnp`.

  Args:
    *arrs: Arrays to stack; at least one is required, as `arrs[0]` is read.

  Returns:
    The result of `stack` called on the tuple of arrays.
  """
  return enp.lazy.get_xnp(arrs[0]).stack(arrs)