from __future__ import annotations

import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union

import torch
from functorch._C import dim as _C

from ._parsing import (
    _ellipsis,
    AnonymousAxis,
    comma_separate,
    parse_pattern,
    validate_rearrange_expressions,
)

__all__ = ["rearrange"]

dims = _C.dims


@functools.lru_cache(256)
def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.

    Since an equivalent result is computed for tensors with the same number of dimensions, the same pattern, and the
    same specified axes lengths, this function can be memoized.

    Args:
        tensor_ndim (int): the number of dimensions in the tensor to rearrange
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
    """
    left, right = parse_pattern(pattern, axes_lengths)
    validate_rearrange_expressions(left, right, axes_lengths)

    n_anon_dims = sum(not dim for dim in left.composition)
    if left.has_ellipsis:
        n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
        n_named_dims = len(left.identifiers) - 1

        if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
                f"dimensions in the tensor ({tensor_ndim})"
            )
    else:
        n_ellipsis_dims = 0
        n_named_dims = len(left.identifiers)

        if (pattern_ndim := len(left.composition)) != tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
                f"the tensor ({tensor_ndim})"
            )
    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims
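    # e.g. (illustrative) 'a b ... c' on a 5-D tensor: 3 named dims, 5 - 3 = 2 ellipsis
    # dims, and 0 anonymous dims, so n_dims = 5
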
    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor

    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
    anon_axes: List[AnonymousAxis] = []

    # map the left-hand side identifiers to strings representing first class dims
    dims_i = 0
    for dimension in left.composition:
        if isinstance(dimension, list):
            for identifier in dimension:
                # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
                assert isinstance(identifier, str)
                identifier_dim_map[identifier] = (first_class_dims[dims_i],)
                dims_i += 1
            if not dimension:
                # unitary anonymous axis
                anon_axis = AnonymousAxis("1")
                identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
                anon_axes.append(anon_axis)
                dimension.append(anon_axis)
                dims_i += 1
        elif dimension == _ellipsis:
            identifier = _ellipsis
            identifier_dim_map[identifier] = tuple(
                first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
            )
            dims_i += n_ellipsis_dims
        else:
            raise ValueError(f"Unexpected dimension: {dimension}")
    def composition_to_dims(
        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
    ) -> List[Union[str, Tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
        dim_composition: List[Union[str, Tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                dim_composition.append(
                    tuple(
                        dim
                        for identifier in dimension
                        for dim in identifier_dim_map[identifier]
                    )
                )
            elif dimension == _ellipsis:
                dim_composition.extend(identifier_dim_map[_ellipsis])
            else:
                raise ValueError(f"Unexpected dimension: {dimension}")
        return dim_composition

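    # illustrative: for 'b (h w) -> b h w', composition_to_dims produces
    # [('d0',), ('d1', 'd2')] for the left side and [('d0',), ('d1',), ('d2',)] for
    # the right, i.e. a `__getitem__` index and the arguments to `order`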
    left_dims = composition_to_dims(left.composition)
    right_dims = composition_to_dims(right.composition)
    anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
    specified_lengths = tuple(
        (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
    )
    custom_rearrange_callable_name = "do_rearrange"
    custom_rearrange_callable_code = (
        (
            f"def {custom_rearrange_callable_name}(tensor):\n"
            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
        )
        + (
            "".join(
                f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
            )
            if specified_lengths
            else ""
        )
        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
        + (
            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
            if anon_dims
            else "    return tensor\n"
        )
    )
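    # executing the generated source defines `do_rearrange` in this frame's locals,
    # from which it is retrieved by name below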
    exec(custom_rearrange_callable_code)

    return locals()[custom_rearrange_callable_name]


def rearrange(
    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
) -> torch.Tensor:
    r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
    tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    See: https://einops.rocks/api/rearrange/

    Args:
        tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
        pattern (str): the rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Tensor: the rearranged tensor

    Examples:
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = torch.randn((32, 30, 40, 3))

        >>> # stack along first (batch) axis, output is a single array
        >>> rearrange(images, 'b h w c -> b h w c').shape
        torch.Size([32, 30, 40, 3])

        >>> # concatenate images along height (vertical axis), 960 = 32 * 30
        >>> rearrange(images, 'b h w c -> (b h) w c').shape
        torch.Size([960, 40, 3])

        >>> # concatenate images along width (horizontal axis), 1280 = 32 * 40
        >>> rearrange(images, 'b h w c -> h (b w) c').shape
        torch.Size([30, 1280, 3])

        >>> # reorder axes to "b c h w" format for deep learning
        >>> rearrange(images, 'b h w c -> b c h w').shape
        torch.Size([32, 3, 30, 40])

        >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> rearrange(images, 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])

        >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        torch.Size([128, 15, 20, 3])

        >>> # space-to-depth operation
        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
        torch.Size([32, 15, 20, 12])
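
        >>> # a sequence of tensors is stacked along a new leading axis before rearranging
        >>> rearrange([torch.randn(30, 40, 3) for _ in range(32)], 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])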
""" | |
if not isinstance(tensor, torch.Tensor): | |
tensor = torch.stack(tensor) | |
rearrange_callable = _create_rearrange_callable( | |
tensor.ndim, pattern, **axes_lengths | |
) | |
return rearrange_callable(tensor) | |