# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to hold expected dimension sizes."""

import math
import re
from typing import Any, Collection, Dict, Optional, Sized, Tuple


Shape = Tuple[Optional[int], ...]


class Dimensions:
  """A lightweight utility that maps strings to shape tuples.

  The most basic usage is:

  .. code::

    >>> dims = chex.Dimensions(B=3, T=5, N=7)  # You can specify any letters.
    >>> dims['NBT']
    (7, 3, 5)
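
  Looking up a dimension that has not been set raises a :code:`KeyError`:

  .. code::

    >>> dims['BTX']
    Traceback (most recent call last):
        ...
    KeyError: 'X'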

  This is useful when dealing with many differently shaped arrays. For instance,
  let's check the shape of this array:

  .. code::

    >>> x = jnp.array([[2, 0, 5, 6, 3],
    ...                [5, 4, 4, 3, 3],
    ...                [0, 0, 5, 2, 0]])
    >>> chex.assert_shape(x, dims['BT'])

  Individual dimension sizes can also be accessed as attributes, e.g.
  :code:`dims.N == 7`. For instance, we can use this to one-hot encode our
  array:

  .. code::

    >>> y = jax.nn.one_hot(x, dims.N)
    >>> chex.assert_shape(y, dims['BTN'])

  You can also store the shape of a given array in :code:`dims`, e.g.

  .. code::

    >>> z = jnp.array([[0, 6, 0, 2],
    ...                [4, 2, 2, 4]])
    >>> dims['XY'] = z.shape
    >>> dims
    Dimensions(B=3, N=7, T=5, X=2, Y=4)
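
  The key string and the shape must have the same length:

  .. code::

    >>> dims['XY'] = (1, 2, 3)
    Traceback (most recent call last):
        ...
    ValueError: key string 'XY' and shape (1, 2, 3) have different lengths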

  You can access the flat size of a shape as

  .. code::

    >>> dims.size('BT')  # Same as prod(dims['BT']).
    15

  You can set a wildcard dimension, cf. :func:`chex.assert_shape`:

  .. code::

    >>> dims.W = None
    >>> dims['BTW']
    (3, 5, None)

  Or you can use the wildcard character :code:`'*'` directly:

  .. code::

    >>> dims['BT*']
    (3, 5, None)
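
  Since a wildcard has no fixed size, :code:`dims.size` rejects keys that
  contain one:

  .. code::

    >>> dims.size('BT*')
    Traceback (most recent call last):
        ...
    ValueError: cannot take product of shape 'BT*' = (3, 5, None), because it contains wildcard dimensions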

  Single digits are interpreted as literal integers. Since each character of
  the key is parsed separately, only single-digit literals are supported.

  .. code::

    >>> dims['BT123']
    (3, 5, 1, 2, 3)

  Support for single digits was mainly included to accommodate dummy axes
  introduced for consistent broadcasting. For instance, instead of using
  :func:`jnp.expand_dims <jax.numpy.expand_dims>` you could do the following:

  .. code::

    >>> w = y * x  # Cannot broadcast (3, 5, 7) with (3, 5)
    Traceback (most recent call last):
        ...
    ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5))
    >>> w = y * x.reshape(dims['BT1'])
    >>> chex.assert_shape(w, dims['BTN'])

  Sometimes you only care about some array dimensions but not all. You can use
  an underscore to ignore an axis, e.g.

  .. code::

    >>> chex.assert_rank(y, 3)
    >>> dims['__M'] = y.shape  # Skip the first two axes.

  Finally, note that a single-character key returns a tuple of length one.

  .. code::

    >>> dims['M']
    (7,)
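
  Dimensions can be removed with :code:`del`, using the same string keys:

  .. code::

    >>> del dims['XY']
    >>> dims
    Dimensions(B=3, M=7, N=7, T=5, W=None)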

  """
  # Tell static type checker not to worry about attribute errors.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, **dim_sizes) -> None:
    for dim, size in dim_sizes.items():
      self._setdim(dim, size)

  def size(self, key: str) -> int:
    """Returns the flat size of a given named shape, i.e. prod(shape)."""
    if None in (shape := self[key]):
      raise ValueError(
          f"cannot take product of shape '{key}' = {shape}, "
          'because it contains wildcard dimensions')
    return math.prod(shape)

  def __getitem__(self, key: str) -> Shape:
    self._validate_key(key)
    return tuple(self._getdim(dim) for dim in key)

  def __setitem__(self, key: str, value: Collection[Optional[int]]) -> None:
    self._validate_key(key)
    self._validate_value(value)
    if len(key) != len(value):
      raise ValueError(
          f'key string {repr(key)} and shape {tuple(value)} '
          'have different lengths')
    for dim, size in zip(key, value):
      self._setdim(dim, size)

  def __delitem__(self, key: str) -> None:
    self._validate_key(key)
    for dim in key:
      self._deldim(dim)

  def __repr__(self) -> str:
    args = ', '.join(f'{k}={v}' for k, v in sorted(self._asdict().items()))
    return f'{type(self).__name__}({args})'

  def _asdict(self) -> Dict[str, Optional[int]]:
    return {k: v for k, v in self.__dict__.items()
            if re.fullmatch(r'[a-zA-Z]', k)}

  def _getdim(self, dim: str) -> Optional[int]:
    if dim == '*':
      return None
    if re.fullmatch(r'[0-9]', dim):
      return int(dim)
    try:
      return getattr(self, dim)
    except AttributeError as e:
      raise KeyError(dim) from e

  def _setdim(self, dim: str, size: Optional[int]) -> None:
    if dim == '_':  # Skip.
      return
    self._validate_dim(dim)
    setattr(self, dim, _optional_int(size))

  def _deldim(self, dim: str) -> None:
    if dim == '_':  # Skip.
      return
    self._validate_dim(dim)
    try:
      delattr(self, dim)
    except AttributeError as e:
      raise KeyError(dim) from e

  def _validate_key(self, key: Any) -> None:
    if not isinstance(key, str):
      raise TypeError(f'key must be a string; got: {type(key).__name__}')

  def _validate_value(self, value: Any) -> None:
    if not isinstance(value, Sized):
      raise TypeError(
          'value must be sized, i.e. an object with a well-defined len(value); '
          f'got object of type: {type(value).__name__}')

  def _validate_dim(self, dim: Any) -> None:
    if not isinstance(dim, str):
      raise TypeError(
          f'dimension name must be a string; got: {type(dim).__name__}')
    if not re.fullmatch(r'[a-zA-Z]', dim):
      raise KeyError(
          'dimension names may only contain letters (or \'_\' to skip); '
          f'got dimension name: {repr(dim)}')


def _optional_int(x: Any) -> Optional[int]:
  if x is None:
    return None
  try:
    i = int(x)
    if x == i:
      return i
  except (TypeError, ValueError):
    pass
  raise TypeError(f'object cannot be interpreted as a python int: {repr(x)}')