# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Unit tests for `baselines.py`."""

import copy
import functools
from typing import Generator

from absl.testing import absltest
from absl.testing import parameterized
import chex

from clrs._src import baselines
from clrs._src import dataset
from clrs._src import probing
from clrs._src import processors
from clrs._src import samplers
from clrs._src import specs

import haiku as hk
import jax
import numpy as np

_Array = np.ndarray


def _error(x, y):
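  """Sums the absolute elementwise difference between two arrays."""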
  return np.sum(np.abs(x - y))


def _make_sampler(algo: str, length: int) -> samplers.Sampler:
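  """Builds a CLRS-30 validation sampler for `algo` at the given length."""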
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=length,
  )
  return sampler


def _without_permutation(feedback):
  """Replace should-be permutations with pointers."""
  outputs = []
  for x in feedback.outputs:
    if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
      outputs.append(x)
      continue
    assert x.location == specs.Location.NODE
    outputs.append(probing.DataPoint(name=x.name, location=x.location,
                                     type_=specs.Type.POINTER, data=x.data))
  return feedback._replace(outputs=outputs)


def _make_iterable_sampler(
    algo: str, batch_size: int,
    length: int) -> Generator[samplers.Feedback, None, None]:
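  """Yields `algo` batches forever, permutations recast as pointers."""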
  sampler = _make_sampler(algo, length)
  while True:
    yield _without_permutation(sampler.next(batch_size))


def _remove_permutation_from_spec(spec):
  """Modify spec to turn permutation type to pointer."""
  new_spec = {}
  for k in spec:
    if (spec[k][1] == specs.Location.NODE and
        spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
      new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
    else:
      new_spec[k] = spec[k]
  return new_spec


class BaselinesTest(parameterized.TestCase):

  def test_full_vs_chunked(self):
    """Test that chunking does not affect gradients."""

    batch_size = 4
    length = 8
    algo = 'insertion_sort'
    spec = _remove_permutation_from_spec(specs.SPECS[algo])
    rng_key = jax.random.PRNGKey(42)

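    # Three identically seeded streams of the same data: unchunked, chunked at
    # one trajectory per chunk, and chunked at two trajectories per chunk.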
    full_ds = _make_iterable_sampler(algo, batch_size, length)
    chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length)
    double_chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length * 2)

    full_batches = [next(full_ds) for _ in range(2)]
    chunked_batches = [next(chunked_ds) for _ in range(2)]
    double_chunk_batch = next(double_chunked_ds)

    with chex.fake_jit():  # Disable jit; compilation makes the test slower.

      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_full = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches[0], **common_args)
      b_full.init(full_batches[0].features, seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      full_params = b_full.params
      full_loss_0 = b_full.feedback(rng_key, full_batches[0])
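      # Reset to the initial params so the second loss is computed from the
      # same starting state.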
      b_full.params = full_params
      full_loss_1 = b_full.feedback(rng_key, full_batches[1])
      new_full_params = b_full.params

      b_chunked = baselines.BaselineModelChunked(
          spec, dummy_trajectory=chunked_batches[0], **common_args)
      b_chunked.init([[chunked_batches[0].features]], seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      chunked_params = b_chunked.params
      jax.tree_util.tree_map(np.testing.assert_array_equal, full_params,
                             chunked_params)
      chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0])
      b_chunked.params = chunked_params
      chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1])
      new_chunked_params = b_chunked.params

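      # Reset once more before evaluating the double-length chunk.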
      b_chunked.params = chunked_params
      double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch)

    # Test that losses match
    np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4)
    np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4)
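    # A double-length chunk spans both trajectories, so its mean loss should be
    # half the sum of the two single-trajectory losses.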
    np.testing.assert_allclose(full_loss_0 + full_loss_1,
                               2 * double_chunked_loss,
                               rtol=1e-4)

    # Test that gradients are the same (parameters changed equally).
    # First check that gradients were not zero, i.e., parameters have changed.
    param_change, _ = jax.tree_util.tree_flatten(
        jax.tree_util.tree_map(_error, full_params, new_full_params))
    self.assertGreater(np.mean(param_change), 0.1)
    # Now check that full and chunked gradients are the same.
    jax.tree_util.tree_map(
        functools.partial(np.testing.assert_allclose, rtol=1e-4),
        new_full_params, new_chunked_params)

  def test_multi_vs_single(self):
    """Test that multi = single when we only train one of the algorithms."""

    batch_size = 4
    length = 16
    algos = ['insertion_sort', 'activity_selector', 'bfs']
    spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
    rng_key = jax.random.PRNGKey(42)

    full_ds = [_make_iterable_sampler(algo, batch_size, length)
               for algo in algos]
    full_batches = [next(ds) for ds in full_ds]
    full_batches_2 = [next(ds) for ds in full_ds]

    with chex.fake_jit():  # Disable jit; compilation makes the test slower.

      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_single = baselines.BaselineModel(
          spec[0], dummy_trajectory=full_batches[0], **common_args)
      b_multi = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches, **common_args)
      b_single.init(full_batches[0].features, seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
      b_multi.init([f.features for f in full_batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

      single_params = []
      single_losses = []
      multi_params = []
      multi_losses = []

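      # Alternate param snapshots with feedback steps, so params[i] is the
      # state in which losses[i] was computed.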
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches[0]))
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches_2[0]))
      single_params.append(copy.deepcopy(b_single.params))

      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))

    # Test that losses match
    np.testing.assert_array_equal(single_losses, multi_losses)
    # Test that loss decreased
    self.assertLess(single_losses[1], single_losses[0])

    # Test that param updates match between single- and multi-algorithm runs.
    for single, multi in zip(single_params, multi_params):
      self.assertTrue(
          hk.data_structures.is_subset(subset=single, superset=multi))
      for module_name, params in single.items():
        jax.tree_util.tree_map(np.testing.assert_array_equal, params,
                               multi[module_name])

    # Test that params change for the trained algorithm, but not the others
    for module_name, params in multi_params[0].items():
      param_changes = jax.tree_util.tree_map(
          _error, params, multi_params[1][module_name])
      param_change = sum(param_changes.values())
      if module_name in single_params[0]:  # params of the trained algorithm
        self.assertGreater(param_change, 1e-3)
      else:  # params of non-trained algorithms
        self.assertEqual(param_change, 0.0)

  @parameterized.parameters(True, False)
  def test_multi_algorithm_idx(self, is_chunked):
    """Test that algorithm selection works as intended."""

    batch_size = 4
    length = 8
    algos = ['insertion_sort', 'activity_selector', 'bfs']
    spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
    rng_key = jax.random.PRNGKey(42)

    if is_chunked:
      ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length),
                             2 * length) for algo in algos]
    else:
      ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos]
    batches = [next(d) for d in ds]

    processor_factory = processors.get_processor_factory(
        'mpnn', use_ln=False, nb_triplet_fts=0)
    common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                       learning_rate=0.01,
                       decode_hints=True, encode_hints=True)
    if is_chunked:
      baseline = baselines.BaselineModelChunked(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([[f.features for f in batches]], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
    else:
      baseline = baselines.BaselineModel(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([f.features for f in batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

    # Find out what parameters change when we train each algorithm
    def _change(x, y):
      changes = {}
      for module_name, params in x.items():
        changes[module_name] = sum(
            jax.tree_util.tree_map(_error, params, y[module_name]).values())
      return changes

    param_changes = []
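    # Take one gradient step per algorithm and record per-module param change.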
    for algo_idx in range(len(algos)):
      init_params = copy.deepcopy(baseline.params)
      _ = baseline.feedback(
          rng_key,
          batches[algo_idx],
          algorithm_index=(0, algo_idx) if is_chunked else algo_idx)
      param_changes.append(_change(init_params, baseline.params))

    # Test that non-changing parameters correspond to encoders/decoders
    # associated with the non-trained algorithms
    unchanged = [[k for k in pc if pc[k] == 0] for pc in param_changes]

    def _get_other_algos(algo_idx, modules):
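      """Encoder/decoder modules of all algorithms except `algo_idx`."""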
      return {k for k in modules
              if '_construct_encoders_decoders' in k
              and f'algo_{algo_idx}' not in k}

    for algo_idx in range(len(algos)):
      expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys())
      self.assertNotEmpty(expected_unchanged)
      self.assertSetEqual(expected_unchanged, set(unchanged[algo_idx]))


if __name__ == '__main__':
  absltest.main()