|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Unit tests for `samplers.py`.""" |
|
|
|
from absl.testing import absltest |
|
from absl.testing import parameterized |
|
|
|
import chex |
|
from clrs._src import probing |
|
from clrs._src import samplers |
|
from clrs._src import specs |
|
import jax |
|
import numpy as np |
|
|
|
|
|
class SamplersTest(parameterized.TestCase):
  """Tests for sampler construction, determinism and the batching helpers."""

  @parameterized.parameters(*specs.CLRS_30_ALGS)
  def test_sampler_determinism(self, name):
    """Two consecutive draws match even if the global NumPy seed changes."""
    num_samples = 3
    num_nodes = 10
    sampler, _ = samplers.build_sampler(name, num_samples, num_nodes)

    np.random.seed(47)  # First global seed.
    expected = sampler.next().outputs[0].data.copy()

    np.random.seed(48)  # Deliberately different global seed.
    actual = sampler.next().outputs[0].data.copy()

    # The first output must be identical across both draws.
    np.testing.assert_array_equal(expected, actual)

  @parameterized.parameters(*specs.CLRS_30_ALGS)
  def test_sampler_batch_determinism(self, name):
    """Two samplers built with the same seed yield identical batches."""
    num_samples = 10
    batch_size = 5
    num_nodes = 10
    seed = 0
    sampler_1, _ = samplers.build_sampler(
        name, num_samples, num_nodes, seed=seed)
    sampler_2, _ = samplers.build_sampler(
        name, num_samples, num_nodes, seed=seed)

    feedback_1 = sampler_1.next(batch_size)
    feedback_2 = sampler_2.next(batch_size)

    # Compare every leaf of the two feedback structures.
    jax.tree_util.tree_map(
        np.testing.assert_array_equal, feedback_1, feedback_2)

  def test_end_to_end(self):
    """Checks names, counts and shapes of the "bfs" inputs and outputs."""
    num_samples = 7
    num_nodes = 3
    sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)
    feedback = sampler.next()

    inputs = feedback.features.inputs
    self.assertLen(inputs, 4)
    self.assertEqual(inputs[0].name, "pos")
    self.assertEqual(inputs[0].data.shape, (num_samples, num_nodes))

    outputs = feedback.outputs
    self.assertLen(outputs, 1)
    self.assertEqual(outputs[0].name, "pi")
    self.assertEqual(outputs[0].data.shape, (num_samples, num_nodes))

  def _assert_batch_dim(self, feedback, expected):
    """Asserts that every array in `feedback` has `expected` samples.

    Inputs and outputs are checked on axis 0, hints on axis 1, and
    `features.lengths` by its length.

    Args:
      feedback: The object returned by `sampler.next(...)`.
      expected: The number of samples each array should hold.
    """
    for dp in feedback.features.inputs:
      self.assertEqual(dp.data.shape[0], expected)

    for dp in feedback.outputs:
      self.assertEqual(dp.data.shape[0], expected)

    for dp in feedback.features.hints:
      self.assertEqual(dp.data.shape[1], expected)

    self.assertLen(feedback.features.lengths, expected)

  def test_batch_size(self):
    """Checks the sample dimension with and without an explicit batch size."""
    num_samples = 7
    num_nodes = 3
    sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)

    # Without a batch size, all `num_samples` samples come back.
    self._assert_batch_dim(sampler.next(), num_samples)

    # With a batch size, exactly that many samples come back.
    batch_size = 5
    self._assert_batch_dim(sampler.next(batch_size), batch_size)

  def test_batch_io(self):
    """Checks that `_batch_io` stacks per-sample arrays along axis 0."""
    sample = [
        probing.DataPoint(
            name="x",
            location=specs.Location.NODE,
            type_=specs.Type.SCALAR,
            data=np.zeros([1, 3]),
        ),
        probing.DataPoint(
            name="y",
            location=specs.Location.EDGE,
            type_=specs.Type.MASK,
            data=np.zeros([1, 3, 3]),
        ),
    ]

    # Four shallow copies of the same sample form the trajectory.
    trajectory = [sample.copy() for _ in range(4)]
    batched = samplers._batch_io(trajectory)

    np.testing.assert_array_equal(batched[0].data, np.zeros([4, 3]))
    np.testing.assert_array_equal(batched[1].data, np.zeros([4, 3, 3]))

  def test_batch_hint(self):
    """Checks shapes and lengths returned by `_batch_hints`."""

    def make_sample(num_steps):
      # One MASK and one POINTER node hint, both of shape (num_steps, 1, 3).
      return [
          probing.DataPoint(
              name="x",
              location=specs.Location.NODE,
              type_=specs.Type.MASK,
              data=np.zeros([num_steps, 1, 3]),
          ),
          probing.DataPoint(
              name="y",
              location=specs.Location.NODE,
              type_=specs.Type.POINTER,
              data=np.zeros([num_steps, 1, 3]),
          ),
      ]

    trajectory = [make_sample(2), make_sample(1)]

    # With second argument 0, the leading axis equals the longest sample (2).
    batched, lengths = samplers._batch_hints(trajectory, 0)

    np.testing.assert_array_equal(batched[0].data, np.zeros([2, 2, 3]))
    np.testing.assert_array_equal(batched[1].data, np.zeros([2, 2, 3]))
    np.testing.assert_array_equal(lengths, np.array([2, 1]))

    # With second argument 5, the leading axis grows to 5 while the reported
    # lengths stay the same (presumably a minimum step count -- confirm
    # against `samplers._batch_hints`).
    batched, lengths = samplers._batch_hints(trajectory, 5)

    np.testing.assert_array_equal(batched[0].data, np.zeros([5, 2, 3]))
    np.testing.assert_array_equal(batched[1].data, np.zeros([5, 2, 3]))
    np.testing.assert_array_equal(lengths, np.array([2, 1]))

  def test_padding(self):
    """Checks that `_batch_hints` zero-fills past each sample's length."""
    # A locally seeded generator keeps the lengths reproducible from run to
    # run and independent of whatever state the global NumPy RNG is in.
    rng = np.random.RandomState(0)
    lens = rng.choice(10, (10,), replace=True) + 1  # Lengths in [1, 10].
    trajectory = [
        [
            probing.DataPoint(
                name="x",
                location=specs.Location.NODE,
                type_=specs.Type.MASK,
                data=np.ones([len_, 1, 3]),
            )
        ]
        for len_ in lens
    ]

    batched, lengths = samplers._batch_hints(trajectory, 0)
    np.testing.assert_array_equal(lengths, lens)

    for i, len_ in enumerate(lens):
      # The first `len_` steps keep the original ones; the rest is padding.
      ones = batched[0].data[:len_, i, :]
      zeros = batched[0].data[len_:, i, :]
      np.testing.assert_array_equal(ones, np.ones_like(ones))
      np.testing.assert_array_equal(zeros, np.zeros_like(zeros))
|
|
|
|
|
class ProcessRandomPosTest(parameterized.TestCase):
  """Tests for `samplers.process_random_pos`."""

  @parameterized.parameters(["insertion_sort", "naive_string_matcher"])
  def test_random_pos(self, algorithm_name):
    """Checks that only the "pos" input differs after `process_random_pos`."""
    batch_size, length = 12, 10

    def _make_sampler():
      # Endless generator of batches from a freshly built, seeded sampler.
      sampler, _ = samplers.build_sampler(
          algorithm_name,
          seed=0,
          num_samples=100,
          length=length,
      )
      while True:
        yield sampler.next(batch_size)

    # Two identically seeded streams; only the second is post-processed.
    sampler_1 = _make_sampler()
    sampler_2 = _make_sampler()
    sampler_2 = samplers.process_random_pos(sampler_2, np.random.RandomState(0))

    batch_without_rand_pos = next(sampler_1)
    batch_with_rand_pos = next(sampler_2)

    input_names = [x.name for x in batch_without_rand_pos.features.inputs]
    pos_idx = input_names.index("pos")
    fixed_pos = batch_without_rand_pos.features.inputs[pos_idx]
    rand_pos = batch_with_rand_pos.features.inputs[pos_idx]

    self.assertEqual(rand_pos.location, specs.Location.NODE)
    self.assertEqual(rand_pos.type_, specs.Type.SCALAR)
    self.assertEqual(rand_pos.data.shape, (batch_size, length))
    self.assertEqual(rand_pos.data.shape, fixed_pos.data.shape)
    self.assertEqual(rand_pos.type_, fixed_pos.type_)
    self.assertEqual(rand_pos.location, fixed_pos.location)

    # Randomised positions vary across the batch; the fixed ones do not.
    # `assertTrue` is used instead of a bare `assert` so the checks still run
    # under `python -O` and report through the test framework.
    self.assertTrue((rand_pos.data.std(axis=0) > 1e-3).all())
    self.assertTrue((fixed_pos.data.std(axis=0) < 1e-9).all())

    if "string" in algorithm_name:
      # Expected positions restart from 0 after the first 4/5 of the entries.
      head_len = 4 * length // 5
      tail_len = length // 5
      expected = np.concatenate([
          np.arange(head_len) / head_len,
          np.arange(tail_len) / tail_len,
      ])
    else:
      expected = np.arange(length) / length
    np.testing.assert_array_equal(
        fixed_pos.data, np.broadcast_to(expected, (batch_size, length)))

    # Once "pos" is swapped back, both batches must be identical everywhere.
    batch_with_rand_pos.features.inputs[pos_idx] = fixed_pos
    chex.assert_trees_all_equal(batch_with_rand_pos, batch_without_rand_pos)
|
|
|
|
|
# Run every test in this module through absltest when executed as a script.
if __name__ == "__main__":
  absltest.main()
|
|