Spaces:
Building
Building
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tree API."""
import concurrent.futures
import functools
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar

from etils import enp
from etils import etqdm
from etils.array_types import Array
from etils.etree import backend as backend_lib
from etils.etree.typing import LeafFn, Tree  # pylint: disable=g-importing-member,g-multiple-import

# Element type of the iterables stored in the tree given to `TreeAPI.unzip`.
_T = Any  # TODO(pytype): Replace by `TypeVar`
# Type of the input leaves forwarded to the mapped function.
_Tin = Any  # Could make this TypeVar if typing support variadic
# Type of the value returned by the mapped function.
_Tout = TypeVar('_Tout')
class TreeAPI:
  """Tree API, using either `jax.tree_util`, `tf.nest` or `tree` backend.

  All structure handling (`map`, `flatten`, `unflatten`) is delegated to the
  `backend` object given at construction.
  """

  def __init__(self, backend: backend_lib.Backend):
    """Constructor.

    Args:
      backend: Backend providing the `map`, `flatten` and `unflatten`
        implementations used by the methods of this class.
    """
    self.backend = backend

  def map(
      self,
      map_fn: Callable[..., _Tout],  # Callable[[_Tin0, _Tin1,...], Tout]
      *trees: Tree[_Tin],  # _Tin0, _Tin1,...
      is_leaf: Optional[LeafFn] = None,
  ) -> Tree[_Tout]:
    """Same as `tree.map_structure`.

    Args:
      map_fn: Worker function
      *trees: Nested input to pass to the `map_fn`
      is_leaf: Don't recurse into leaf if `is_leaf(node)` is `True`

    Returns:
      The nested structure after `map_fn` has been applied.
    """
    return self.backend.map(map_fn, *trees, is_leaf=is_leaf)

  def parallel_map(
      self,
      map_fn: Callable[..., _Tout],  # Callable[[_Tin0, _Tin1,...], Tout]
      *trees: Tree[_Tin],  # _Tin0, _Tin1,...
      num_threads: Optional[int] = None,
      progress_bar: bool = False,
      is_leaf: Optional[LeafFn] = None,
  ) -> Tree[_Tout]:
    """Same as `tree.map_structure` but apply `map_fn` in parallel.

    Each leaf is submitted to a `concurrent.futures.ThreadPoolExecutor`. The
    call blocks until every task has completed (or one of them failed).

    Args:
      map_fn: Worker function
      *trees: Nested input to pass to the `map_fn`
      num_threads: Number of worker threads. `None` uses the
        `concurrent.futures.ThreadPoolExecutor` default (which depends on the
        Python version).
      progress_bar: If True, display a progression bar.
      is_leaf: Don't recurse into leaf if `is_leaf(node)` is `True`

    Returns:
      The nested structure after `map_fn` has been applied.

    Raises:
      Exception: Any exception raised by `map_fn` in a worker thread is
        re-raised in the calling thread. Tasks which have not started yet are
        cancelled before the exception propagates.
    """
    # TODO(epot): Allow nesting `parallel_map` while keeping max num threads
    # constant. How to avoid dead locks ?
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=num_threads
    ) as executor:
      launch_worker = functools.partial(executor.submit, map_fn)
      futures = self.backend.map(launch_worker, *trees, is_leaf=is_leaf)

      # The user `is_leaf` describes the input trees, not the tree of futures:
      # re-applying it here could match a container of `Future` (e.g. a
      # predicate "list without nested containers") and hand a container,
      # rather than the futures, to `as_completed`.
      leaves, _ = self.backend.flatten(futures)

      itr = concurrent.futures.as_completed(leaves)
      if progress_bar:
        itr = etqdm.tqdm(itr, total=len(leaves))

      try:
        for future in itr:
          future.result()  # Propagate exception to main thread.
      except BaseException:
        # Fail fast: leaving the `with` block waits for the executor to shut
        # down, so drop the tasks still queued rather than running them all
        # before the error surfaces. `cancel()` is a no-op on running or
        # finished futures.
        for future in leaves:
          future.cancel()
        raise

    return self.backend.map(lambda f: f.result(), futures)

  def unzip(self, tree: Tree[Iterable[_T]]) -> Iterator[Tree[_T]]:
    """Unpack a tree of iterable.

    This is the reverse operation of `tree.map_structure(zip, *trees)`

    Example:

    ```python
    list(etree.unzip({'a': np.array([1, 2, 3])})) == [
        {'a': 1}, {'a': 2}, {'a': 3}
    ]
    ```

    The leaves are iterated in lockstep with `zip`, so iteration stops as soon
    as the shortest leaf is exhausted.

    Args:
      tree: The tree to unzip

    Yields:
      Trees of same structure than the input, but with individual elements.
    """
    leaves, treedef = self.backend.flatten(tree)
    for leaf_elems in zip(*leaves):  # TODO(py310): strict=True
      yield self.backend.unflatten(treedef, leaf_elems)

  def stack(
      self, trees: Iterable[Tree[Array['*s']]]
  ) -> Tree[Array['n_trees *s']]:
    """Stack an iterable of trees of arrays into a single tree of arrays.

    Supports `jax`, `tf`, `np`.

    Example:

    ```python
    etree.stack([
        {'a': np.array([1])},
        {'a': np.array([2])},
        {'a': np.array([3])},
    ]) == {
        'a': np.array([[1], [2], [3]])
    }
    ```

    Args:
      trees: The list of tree to stack

    Returns:
      Tree of arrays.

    Raises:
      ValueError: If `trees` is empty.
    """
    # Materialize once: `trees` may be a one-shot iterator.
    trees = tuple(trees)
    if not trees:
      raise ValueError('`stack` requires at least one tree, got none.')
    return self.backend.map(_stack, *trees)

  def spec_like(
      self,
      tree: Tree[Array],
      *,
      ignore_other: bool = True,
  ) -> Tree[enp.ArraySpec]:
    """Inspect a tree of array, works with any array type.

    Example:

    ```python
    model = MyModel()
    variables = model.init(jax.random.PRNGKey(0), x)

    # Inspect the `variables` tree structures
    print(etree.spec_like(variables))
    ```

    Args:
      tree: The tree of array
      ignore_other: If `True`, non-array are forwarded as-is.

    Returns:
      The tree of `enp.ArraySpec`.

    Raises:
      TypeError: If a leaf is not recognized by `enp.ArraySpec.is_array` and
        `ignore_other` is `False`.
    """

    def _to_spec_array(array):
      if enp.ArraySpec.is_array(array):
        return enp.ArraySpec.from_array(array)
      if ignore_other:
        return array
      raise TypeError(f'Unknown array type: {type(array)}')

    return self.backend.map(_to_spec_array, tree)
def _stack(*arrs: Array) -> Array:
  """Stack the given arrays into a single array.

  The array module is looked up from the first array through
  `enp.lazy.get_xnp`, and its `stack` function is applied to all the arrays.

  Args:
    *arrs: Arrays to stack (presumably all of the same array library, since
      only the first one is inspected).

  Returns:
    The result of `xnp.stack(arrs)`.
  """
  return enp.lazy.get_xnp(arrs[0]).stack(arrs)