File size: 4,095 Bytes
f5f3483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``dimensions`` module."""

import doctest

from absl.testing import absltest
from absl.testing import parameterized
from chex._src import asserts
from chex._src import dimensions
import jax
import numpy as np


class _ChexModule:
  """Stand-in for the ``chex`` namespace handed to the docstring examples.

  Only the names assigned below are exposed, so example code written as
  ``chex.<name>`` can resolve them without importing the full package.
  """
  Dimensions = dimensions.Dimensions  # pylint: disable=invalid-name
  assert_rank = asserts.assert_rank
  assert_shape = asserts.assert_shape


class DimensionsTest(parameterized.TestCase):
  """Tests for ``dimensions.Dimensions`` item access and ``size``."""

  def test_docstring_examples(self):
    """Runs the ``Dimensions`` docstring examples and fails on any mismatch."""
    globs = {'chex': _ChexModule, 'jax': jax, 'jnp': jax.numpy}
    # ``doctest.run_docstring_examples`` only prints failing examples and
    # returns ``None``, so it can never make this test fail. Drive the finder
    # and runner directly (same configuration that helper uses internally) and
    # assert on the recorded failure count instead.
    finder = doctest.DocTestFinder(recurse=False)
    runner = doctest.DocTestRunner()
    for test in finder.find(dimensions.Dimensions, globs=globs):
      runner.run(test)
    self.assertEqual(runner.failures, 0)

  @parameterized.named_parameters([
      ('scalar', '', (), ()),
      ('vector', 'a', (7,), (7,)),
      ('list', 'ab', [7, 11], (7, 11)),
      ('numpy_array', 'abc', np.array([7, 11, 13]), (7, 11, 13)),
      ('case_sensitive', 'aA', (7, 11), (7, 11)),
  ])
  def test_set_ok(self, k, v, shape):
    """Checks that sizes assigned via ``dims[k] = v`` are read back in a shape.

    Args:
      k: string of dimension names to assign.
      v: sized collection of sizes assigned to ``k``.
      shape: tuple expected for ``k`` when it is looked up again.
    """
    dims = dimensions.Dimensions(x=23, y=29)
    dims[k] = v
    # Sandwich ``k`` between the pre-existing ``x`` and ``y`` dimensions.
    asserts.assert_shape(np.empty((23, *shape, 29)), dims['x' + k + 'y'])

  def test_set_wildcard(self):
    """Checks ``_`` and ``*`` handling in keys used for assignment."""
    dims = dimensions.Dimensions(x=23, y=29)
    # Positions marked ``_`` are skipped: only ``a`` and ``b`` get a size.
    dims['a_b__'] = (7, 11, 13, 17, 19)
    self.assertEqual(dims['xayb'], (23, 7, 29, 13))
    # ``*`` is rejected when assigning.
    with self.assertRaisesRegex(KeyError, r'\*'):
      dims['ab*'] = (7, 11, 13)

  def test_get_wildcard(self):
    """Checks ``*`` and ``_`` handling in keys used for lookup."""
    dims = dimensions.Dimensions(x=23, y=29)
    # Each ``*`` yields ``None``, which ``assert_shape`` accepts for any size.
    self.assertEqual(dims['x*y**'], (23, None, 29, None, None))
    asserts.assert_shape(np.empty((23, 1, 29, 2, 3)), dims['x*y**'])
    # ``_`` is rejected when looking up.
    with self.assertRaisesRegex(KeyError, r'\_'):
      dims['xy_']  # pylint: disable=pointless-statement

  def test_get_literals(self):
    """Checks that digits in a lookup key yield those literal sizes."""
    dims = dimensions.Dimensions(x=23, y=29)
    self.assertEqual(dims['x1y23'], (23, 1, 29, 2, 3))

  @parameterized.named_parameters([
      ('scalar', 'a', 7, TypeError, r'value must be sized'),
      ('iterator', 'a', (x for x in [7]), TypeError, r'value must be sized'),
      ('len_mismatch', 'ab', (7, 11, 13), ValueError, r'different length'),
      ('non_integer_size', 'a', (7.001,),
       TypeError, r'cannot be interpreted as a python int'),
      ('bad_key_type', 13, (7,), TypeError, r'key must be a string'),
      ('bad_key_string', '@%^#', (7, 11, 13, 17), KeyError, r'\@'),
  ])
  def test_set_exception(self, k, v, e, m):
    """Checks that invalid assignments raise the expected exception.

    Args:
      k: key used for the assignment.
      v: value assigned to ``k``.
      e: exception type expected to be raised.
      m: regex the exception message must match.
    """
    dims = dimensions.Dimensions(x=23, y=29)
    with self.assertRaisesRegex(e, m):
      dims[k] = v

  @parameterized.named_parameters([
      ('bad_key_type', 13, TypeError, r'key must be a string'),
      ('bad_key_string', '@%^#', KeyError, r'\@'),
  ])
  def test_get_exception(self, k, e, m):
    """Checks that invalid lookups raise the expected exception.

    Args:
      k: key used for the lookup.
      e: exception type expected to be raised.
      m: regex the exception message must match.
    """
    dims = dimensions.Dimensions(x=23, y=29)
    with self.assertRaisesRegex(e, m):
      dims[k]  # pylint: disable=pointless-statement

  @parameterized.named_parameters([
      ('scalar', '', (), 1),
      ('nonscalar', 'ab', (3, 5), 15),
  ])
  def test_size_ok(self, names, shape, expected_size):
    """Checks ``size`` against the expected value for the named dimensions.

    Args:
      names: string of dimension names, one character per dimension.
      shape: sizes paired with ``names`` to construct the ``Dimensions``.
      expected_size: value ``dims.size(names)`` is expected to return.
    """
    dims = dimensions.Dimensions(**dict(zip(names, shape)))
    self.assertEqual(dims.size(names), expected_size)

  @parameterized.named_parameters([
      ('named', 'ab'),
      ('asterisk', 'a*'),
  ])
  def test_size_fail_wildcard(self, names):
    """Checks ``size`` raises for a ``None``-sized dimension or a ``*``.

    Args:
      names: key passed to ``dims.size``.
    """
    dims = dimensions.Dimensions(a=3, b=None)
    with self.assertRaisesRegex(ValueError, r'cannot take product of shape'):
      dims.size(names)


# Hand control to absltest's entry point when this file is run as a script.
if __name__ == '__main__':
  absltest.main()