|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Unit tests for `evaluation.py`.""" |
|
|
|
from absl.testing import absltest |
|
from clrs._src import evaluation |
|
from clrs._src import probing |
|
from clrs._src import specs |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
|
|
class EvaluationTest(absltest.TestCase):
  """Tests for helpers defined in `clrs._src.evaluation`."""

  def test_reduce_permutations(self):
    """`fuse_perm_and_mask` merges a permutation pointer with a one-hot mask.

    The fused DataPoint keeps the permutation's name and NODE location, is
    typed as a plain POINTER, and its data equals the permutation indices
    except at the masked node of each batch element, which points to itself.
    """
    batch_size, num_nodes = 8, 16

    # One independent random permutation of the nodes per batch element.
    perm_keys = [jax.random.PRNGKey(seed) for seed in range(batch_size)]
    permutations = jnp.stack(
        [jax.random.permutation(key, num_nodes) for key in perm_keys])
    # One randomly chosen "head" node per batch element.
    head_nodes = jax.random.randint(
        jax.random.PRNGKey(42), (batch_size,), 0, num_nodes)

    perm_point = probing.DataPoint(
        name='test',
        type_=specs.Type.PERMUTATION_POINTER,
        location=specs.Location.NODE,
        data=jax.nn.one_hot(permutations, num_nodes),
    )
    mask_point = probing.DataPoint(
        name='test_mask',
        type_=specs.Type.MASK_ONE,
        location=specs.Location.NODE,
        data=jax.nn.one_hot(head_nodes, num_nodes),
    )

    fused = evaluation.fuse_perm_and_mask(perm=perm_point, mask=mask_point)

    # np.array (not np.asarray) gives a writable host copy we can edit in
    # place: the head node of every batch element must become a self-pointer.
    want = np.array(permutations)
    batch_idx = np.arange(batch_size)
    want[batch_idx, head_nodes] = head_nodes

    self.assertEqual(fused.name, 'test')
    self.assertEqual(fused.type_, specs.Type.POINTER)
    self.assertEqual(fused.location, specs.Location.NODE)
    np.testing.assert_allclose(fused.data, want)
|
|
|
|
|
if __name__ == '__main__':
  # Entry point when the module is executed directly: hand control to
  # absl's test runner so the TestCase classes above are run.
  absltest.main()
|
|