PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Itertools utils."""
from __future__ import annotations
import collections
import itertools
from typing import Any, Callable, Iterable, Iterator, TypeVar
# from typing_extensions import Unpack, TypeVarTuple # pytype: disable=not-supported-yet # pylint: disable=g-multiple-import
# TODO(pytype): Once supported, should replace the `Any` placeholders below by
# the real `Unpack` / `TypeVarTuple` from the commented-out import above.
Unpack = Any
TypeVarTuple = Any
# Element type of the iterable given to `splitby`.
_T = TypeVar('_T')
# Key type of the dicts given to `zip_dict`.
_KeyT = TypeVar('_KeyT')
# Placeholder for the value types of the dicts given to `zip_dict`.
_ValuesT = Any # TypeVarTuple('_ValuesT')
# Type returned by the `key` callable of `groupby` (the group keys).
_K = TypeVar('_K')
# Input element type / transformed value type used by `groupby` (`_Tin` is
# also used by `_identity`).
_Tin = TypeVar('_Tin')
_Tout = TypeVar('_Tout')
def _identity(x: _Tin) -> _Tin:
  """Returns `x` unchanged (used as the default `value` of `groupby`)."""
  return x
def groupby(
    iterable: Iterable[_Tin],
    *,
    key: Callable[[_Tin], _K],
    value: Callable[[_Tin], _Tout] = _identity,
) -> dict[_K, list[_Tout]]:
  """Groups the items of `iterable` into a `dict` of lists.

  Dict-returning counterpart of `itertools.groupby`.

  Example:

  ```python
  out = epy.groupby(
      ['555', '4', '11', '11', '333'],
      key=len,
      value=int,
  )
  # Keys appear in the order they were first encountered
  assert out == {
      3: [555, 333],
      1: [4],
      2: [11, 11],
  }
  ```

  Differences with `itertools.groupby`:

  * The input does not have to be sorted: items sharing a key end up in the
    same group even when they are not adjacent, and each group keeps the
    original iteration order.
  * Each item can be transformed (through `value`) before being stored.

  Args:
    iterable: The items to group
    key: Function computing the group of an item (result must be hashable)
    value: Function applied to an item before it is added to its group

  Returns:
    A `dict` mapping each key to the list of transformed items of that group.
  """
  grouped: dict[_K, list[_Tout]] = {}
  for item in iterable:
    # `key` is called before `value`, and the group list is created on the
    # first occurrence of a key.
    grouped.setdefault(key(item), []).append(value(item))
  return grouped
def splitby(
    iterable: Iterable[_T], predicate: Callable[[_T], bool]
) -> tuple[list[_T], list[_T]]:
  """Partitions `iterable` in two lists according to `predicate`.

  Example:

  ```python
  small, big = epy.splitby([100, 4, 4, 1, 200], lambda x: x > 10)
  assert small == [4, 4, 1]
  assert big == [100, 200]
  ```

  Args:
    iterable: The items to partition
    predicate: Function deciding in which list each item goes

  Returns:
    A `(rejected, accepted)` tuple: first the items for which `predicate` was
    falsy, then the items for which it was truthy. The original order is kept
    inside each list.
  """
  rejected, accepted = [], []
  for item in iterable:
    destination = accepted if predicate(item) else rejected
    destination.append(item)
  return rejected, accepted
def zip_dict(  # pytype: disable=invalid-annotation
    *dicts: Unpack[dict[_KeyT, _ValuesT]],
) -> Iterator[tuple[_KeyT, tuple[Unpack[_ValuesT]]]]:
  """Iterate over items of dictionaries grouped by their keys.

  Example:

  ```python
  d0 = {'a': 1, 'b': 2}
  d1 = {'a': 10, 'b': 20}
  d2 = {'a': 100, 'b': 200}

  list(epy.zip_dict(d0, d1, d2)) == [
      ('a', (1, 10, 100)),
      ('b', (2, 20, 200)),
  ]
  ```

  Like `zip()`, calling it without any dict yields nothing.

  Args:
    *dicts: The dicts to iterate over. Should all have the same keys

  Yields:
    `(key, (value_0, value_1, ...))` pairs, in the key order of the first
    dict, with one value per input dict.

  Raises:
    KeyError: If the dicts do not all contain the same keys. As this is a
      generator, the error is raised on the first iteration, before any item
      is yielded.
  """
  if not dicts:
    return  # Nothing to zip (consistent with `zip()`)

  # Set does not keep order like dict, so only use set to compare keys
  all_keys = set(itertools.chain(*dicts))

  # `all_keys` is the union of all keys, so a dict has exactly the same keys
  # as the others if and only if it has as many keys as the union. Checking
  # every dict up-front avoids yielding a few items before failing.
  for d in dicts:
    if len(d) != len(all_keys):
      raise KeyError(f'Missing keys: {all_keys ^ set(d)}')

  # Iterate on the first dict (rather than `all_keys`) to preserve key order.
  for key in dicts[0]:
    yield key, tuple(d[key] for d in dicts)