Spaces:
Building
Building
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Utilities to hold expected dimension sizes.""" | |
import math | |
import re | |
from typing import Any, Collection, Dict, Optional, Sized, Tuple | |
Shape = Tuple[Optional[int], ...] | |
class Dimensions: | |
"""A lightweight utility that maps strings to shape tuples. | |
The most basic usage is: | |
.. code:: | |
>>> dims = chex.Dimensions(B=3, T=5, N=7) # You can specify any letters. | |
>>> dims['NBT'] | |
(7, 3, 5) | |
This is useful when dealing with many differently shaped arrays. For instance, | |
let's check the shape of this array: | |
.. code:: | |
>>> x = jnp.array([[2, 0, 5, 6, 3], | |
... [5, 4, 4, 3, 3], | |
... [0, 0, 5, 2, 0]]) | |
>>> chex.assert_shape(x, dims['BT']) | |
The dimension sizes can be gotten directly, e.g. :code:`dims.N == 7`. This can | |
be useful in many applications. For instance, let's one-hot encode our array. | |
.. code:: | |
>>> y = jax.nn.one_hot(x, dims.N) | |
>>> chex.assert_shape(y, dims['BTN']) | |
You can also store the shape of a given array in :code:`dims`, e.g. | |
.. code:: | |
>>> z = jnp.array([[0, 6, 0, 2], | |
... [4, 2, 2, 4]]) | |
>>> dims['XY'] = z.shape | |
>>> dims | |
Dimensions(B=3, N=7, T=5, X=2, Y=4) | |
You can access the flat size of a shape as | |
.. code:: | |
>>> dims.size('BT') # Same as prod(dims['BT']). | |
15 | |
You can set a wildcard dimension, cf. :func:`chex.assert_shape`: | |
.. code:: | |
>>> dims.W = None | |
>>> dims['BTW'] | |
(3, 5, None) | |
Or you can use the wildcard character `'*'` directly: | |
.. code:: | |
>>> dims['BT*'] | |
(3, 5, None) | |
Single digits are interpreted as literal integers. Note that this notation | |
is limited to single-digit literals. | |
.. code:: | |
>>> dims['BT123'] | |
(3, 5, 1, 2, 3) | |
Support for single digits was mainly included to accommodate dummy axes | |
introduced for consistent broadcasting. For instance, instead of using | |
:func:`jnp.expand_dims <jax.numpy.expand_dims>` you could do the following: | |
.. code:: | |
>>> w = y * x # Cannot broadcast (3, 5, 7) with (3, 5) | |
Traceback (most recent call last): | |
... | |
ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5)) | |
>>> w = y * x.reshape(dims['BT1']) | |
>>> chex.assert_shape(w, dims['BTN']) | |
Sometimes you only care about some array dimensions but not all. You can use | |
an underscore to ignore an axis, e.g. | |
.. code:: | |
>>> chex.assert_rank(y, 3) | |
>>> dims['__M'] = y.shape # Skip the first two axes. | |
Finally note that a single-character key returns a tuple of length one. | |
.. code:: | |
>>> dims['M'] | |
(7,) | |
""" | |
# Tell static type checker not to worry about attribute errors. | |
_HAS_DYNAMIC_ATTRIBUTES = True | |
def __init__(self, **dim_sizes) -> None: | |
for dim, size in dim_sizes.items(): | |
self._setdim(dim, size) | |
def size(self, key: str) -> int: | |
"""Returns the flat size of a given named shape, i.e. prod(shape).""" | |
if None in (shape := self[key]): | |
raise ValueError( | |
f"cannot take product of shape '{key}' = {shape}, " | |
'because it contains wildcard dimensions') | |
return math.prod(shape) | |
def __getitem__(self, key: str) -> Shape: | |
self._validate_key(key) | |
return tuple(self._getdim(dim) for dim in key) | |
def __setitem__(self, key: str, value: Collection[Optional[int]]) -> None: | |
self._validate_key(key) | |
self._validate_value(value) | |
if len(key) != len(value): | |
raise ValueError( | |
f'key string {repr(key)} and shape {tuple(value)} ' | |
'have different lengths') | |
for dim, size in zip(key, value): | |
self._setdim(dim, size) | |
def __delitem__(self, key: str) -> None: | |
self._validate_key(key) | |
for dim in key: | |
self._deldim(dim) | |
def __repr__(self) -> str: | |
args = ', '.join(f'{k}={v}' for k, v in sorted(self._asdict().items())) | |
return f'{type(self).__name__}({args})' | |
def _asdict(self) -> Dict[str, Optional[int]]: | |
return {k: v for k, v in self.__dict__.items() | |
if re.fullmatch(r'[a-zA-Z]', k)} | |
def _getdim(self, dim: str) -> Optional[int]: | |
if dim == '*': | |
return None | |
if re.fullmatch(r'[0-9]', dim): | |
return int(dim) | |
try: | |
return getattr(self, dim) | |
except AttributeError as e: | |
raise KeyError(dim) from e | |
def _setdim(self, dim: str, size: Optional[int]) -> None: | |
if dim == '_': # Skip. | |
return | |
self._validate_dim(dim) | |
setattr(self, dim, _optional_int(size)) | |
def _deldim(self, dim: str) -> None: | |
if dim == '_': # Skip. | |
return | |
self._validate_dim(dim) | |
try: | |
return delattr(self, dim) | |
except AttributeError as e: | |
raise KeyError(dim) from e | |
def _validate_key(self, key: Any) -> None: | |
if not isinstance(key, str): | |
raise TypeError(f'key must be a string; got: {type(key).__name__}') | |
def _validate_value(self, value: Any) -> None: | |
if not isinstance(value, Sized): | |
raise TypeError( | |
'value must be sized, i.e. an object with a well-defined len(value); ' | |
f'got object of type: {type(value).__name__}') | |
def _validate_dim(self, dim: Any) -> None: | |
if not isinstance(dim, str): | |
raise TypeError( | |
f'dimension name must be a string; got: {type(dim).__name__}') | |
if not re.fullmatch(r'[a-zA-Z]', dim): | |
raise KeyError( | |
'dimension names may only be contain letters (or \'_\' to skip); ' | |
f'got dimension name: {repr(dim)}') | |
def _optional_int(x: Any) -> Optional[int]: | |
if x is None: | |
return None | |
try: | |
i = int(x) | |
if x == i: | |
return i | |
except ValueError: | |
pass | |
raise TypeError(f'object cannot be interpreted as a python int: {repr(x)}') | |