Spaces:
Building
Building
# Copyright 2024 The etils Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Itertools utils.""" | |
from __future__ import annotations | |
import collections | |
import itertools | |
from typing import Any, Callable, Iterable, Iterator, TypeVar | |
# from typing_extensions import Unpack, TypeVarTuple # pytype: disable=not-supported-yet # pylint: disable=g-multiple-import | |
# TODO(pytype): Once supported, should replace | |
Unpack = Any | |
TypeVarTuple = Any | |
_T = TypeVar('_T') | |
_KeyT = TypeVar('_KeyT') | |
_ValuesT = Any # TypeVarTuple('_ValuesT') | |
_K = TypeVar('_K') | |
_Tin = TypeVar('_Tin') | |
_Tout = TypeVar('_Tout') | |
def _identity(x: _Tin) -> _Tin: | |
"""Pass through function.""" | |
return x | |
def groupby(
    iterable: Iterable[_Tin],
    *,
    key: Callable[[_Tin], _K],
    value: Callable[[_Tin], _Tout] = _identity,
) -> dict[_K, list[_Tout]]:
  """Groups the items of `iterable` into a `dict` of lists.

  This is a dict-returning alternative to `itertools.groupby`.

  Example:

  ```python
  out = epy.groupby(
      ['555', '4', '11', '11', '333'],
      key=len,
      value=int,
  )
  # Order is consistent with above
  assert out == {
      3: [555, 333],
      1: [4],
      2: [11, 11],
  }
  ```

  Unlike `itertools.groupby`:

  * The input does not need to be sorted: all items sharing a key land in the
    same group, and each group keeps the original iteration order.
  * A `value` transformation can be applied to each item before storing it.

  Args:
    iterable: The items to group.
    key: Called once per item; its result (must be hashable) selects the
      group.
    value: Called once per item; its result is what gets stored in the group.
      Defaults to the item itself.

  Returns:
    A plain `dict` mapping each key to its list of transformed items, with
    keys ordered by first appearance.
  """
  buckets = collections.defaultdict(list)
  for item in iterable:
    group_key = key(item)
    buckets[group_key].append(value(item))
  # Convert so callers get a regular dict (no auto-insert on missing keys).
  return dict(buckets)
def splitby(
    iterable: Iterable[_T], predicate: Callable[[_T], bool]
) -> tuple[list[_T], list[_T]]:
  """Partitions `iterable` into two lists according to `predicate`.

  Example:

  ```python
  small, big = epy.splitby([100, 4, 4, 1, 200], lambda x: x > 10)
  assert small == [4, 4, 1]
  assert big == [100, 200]
  ```

  Args:
    iterable: The items to partition.
    predicate: Called once per item; the truthiness of its result decides
      which list the item goes to.

  Returns:
    A `(false_list, true_list)` tuple: the items for which `predicate` was
    falsy, then those for which it was truthy, each in original order.
  """
  # Slot 0 collects falsy results, slot 1 truthy ones (`bool` is an `int`).
  buckets = ([], [])
  for item in iterable:
    buckets[bool(predicate(item))].append(item)
  return buckets
def zip_dict(  # pytype: disable=invalid-annotation
    *dicts: Unpack[dict[_KeyT, _ValuesT]],
) -> Iterator[tuple[_KeyT, tuple[Unpack[_ValuesT]]]]:
  """Iterates over the items of several dicts, grouped by key.

  Example:

  ```python
  d0 = {'a': 1, 'b': 2}
  d1 = {'a': 10, 'b': 20}
  d2 = {'a': 100, 'b': 200}
  list(epy.zip_dict(d0, d1, d2)) == [
      ('a', (1, 10, 100)),
      ('b', (2, 20, 200)),
  ]
  ```

  Args:
    *dicts: The dicts to iterate over. They must all have the same keys.

  Yields:
    `(key, (d0[key], d1[key], ...))` tuples, in the key order of the first
    dict. Nothing is yielded when no dict is given.

  Raises:
    KeyError: If the dicts do not all contain the same keys. Because this is
      a generator, the check runs on the first `next()` call, before any item
      is yielded.
  """
  if not dicts:
    return
  d0 = dicts[0]
  # Sets do not keep insertion order, so they are only used to compare keys;
  # the iteration order below comes from `d0`.
  all_keys = set(itertools.chain(*dicts))
  # Keys present in every dict (iterating a dict yields its keys).
  common_keys = set(d0).intersection(*dicts[1:])
  # Validate up front so a dict lacking one of `d0`'s keys is reported before
  # any partial result is yielded.
  missing_keys = all_keys - common_keys
  if missing_keys:
    raise KeyError(f'Missing keys: {missing_keys}')
  for key in d0:
    yield key, tuple(d[key] for d in dicts)