"""Unit tests for `losses.py`."""

from typing import Generator

from absl.testing import absltest
from absl.testing import parameterized

from clrs._src import dataset
from clrs._src import losses
from clrs._src import probing
from clrs._src import samplers
from clrs._src import specs
import jax
import jax.numpy as jnp
import numpy as np

_Array = np.ndarray
_Location = specs.Location


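# The tests below check that the full-trajectory losses in `losses.py` agree
# with their time-chunked counterparts (`*_chunked`) when the chunk length is
# set to the length of a full trajectory. The helpers that follow build the
# matching full and chunked data pipelines from identically seeded samplers.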
def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=nb_nodes,
  )
  return sampler


def _make_iterable_sampler(
    algo: str, batch_size: int,
    nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
  sampler = _make_sampler(algo, nb_nodes)
  while True:
    yield sampler.next(batch_size)


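# A "prediction" is faked by permuting the ground truth along the batch axis,
# so it has the right shape and value range but differs from the truth.
# POINTER data is additionally one-hot encoded over the nodes to mimic the
# shape a network's output would have.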
def _as_pred_data(x, nb_nodes, seed, batch_axis):
  """Fake a prediction from a data point."""
  key = jax.random.PRNGKey(seed)
  data = jax.random.permutation(key, x.data, axis=batch_axis)

  if x.type_ == specs.Type.POINTER:
    return jax.nn.one_hot(data, nb_nodes)
  return data


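# Roughly 20% of ground-truth entries are replaced with
# `specs.OutputClass.MASKED` so the losses' handling of masked-out truth is
# exercised too. For CATEGORICAL / MASK_ONE data whole class vectors are
# masked; passing `t_axis` collapses that axis so the same mask is broadcast
# across all time steps.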
def _mask_datapoint(x, seed, t_axis=None):
  """Add some masking to data."""
  key = jax.random.PRNGKey(seed)
  data = x.data
  if x.type_ == specs.Type.MASK:
    mask_shape = list(data.shape)
    if t_axis is not None:
      mask_shape[t_axis] = 1
    mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
    data = jnp.where(mask, specs.OutputClass.MASKED, data)
  elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]:
    mask_shape = list(data.shape)[:-1]
    if t_axis is not None:
      mask_shape[t_axis] = 1
    mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
    data = jnp.where(mask[..., None], specs.OutputClass.MASKED, data)
  return probing.DataPoint(name=x.name, location=x.location, type_=x.type_,
                           data=data)


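# Random-value helpers: `_rand_diff` draws uniform values in [-1, 1) and
# `_rand_mask` a float 0/1 mask with roughly a (1 - p) fraction of ones.
# They are not used by the tests below.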
def _rand_diff(seed, shape):
  return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0


def _rand_mask(seed, shape, p=0.5):
  return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float)


def invert(d):
  """Dict of lists -> list of dicts."""
  if d:
    return [dict(zip(d, i)) for i in zip(*d.values())]


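# Builds a matched pair of batches: one straight from the iterable sampler
# and one from a `dataset.chunkify`-wrapped copy, with the chunk length set
# to the first sample's full trajectory length so the chunked batch carries
# the same data as the full one.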
def _create_data(algo, nb_nodes):
  batch_size = 8

  ds = _make_iterable_sampler(algo, batch_size, nb_nodes)
  full_sample = next(ds)

  chunk_length = full_sample.features.lengths[0].astype(int)
  chunked_ds = dataset.chunkify(
      _make_iterable_sampler(algo, batch_size, nb_nodes),
      chunk_length)
  chunk_sample = next(chunked_ds)
  return full_sample, chunk_sample


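# Both tests are parameterized over 'dfs' and 'floyd_warshall', whose output
# and hint probes exercise several of the data types and locations handled
# by `losses.py`.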
class FullVsChunkLossesTest(parameterized.TestCase):
  """Test that the full and chunked versions of the losses match."""

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_output_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)

    for truth_full, truth_chunked in zip(full_sample.outputs,
                                         chunk_sample.outputs):
      chunk_output_loss = losses.output_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=0),
          pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1),
          is_last=chunk_sample.features.is_last,
          nb_nodes=nb_nodes,
      )
      full_output_loss = losses.output_loss(
          truth=_mask_datapoint(truth_full, seed=0),
          pred=_as_pred_data(truth_full, nb_nodes, 0, 0),
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_output_loss, full_output_loss,
                                 rtol=1e-4)

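  # The chunked hint loss discounts the first step of each trajectory via
  # `is_first`, while the full version gets predictions with the initial hint
  # step dropped (`pred[1:]`) plus per-sample lengths; both bookkeeping
  # schemes should produce the same value.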
  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_hint_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)
    for truth_full, truth_chunked in zip(full_sample.features.hints,
                                         chunk_sample.features.hints):
      np.testing.assert_array_equal(truth_full.data, truth_chunked.data)
      pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1)
      chunk_hint_loss = losses.hint_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0),
          pred=pred,
          is_first=chunk_sample.features.is_first,
          nb_nodes=nb_nodes,
      )

      full_preds = pred[1:]
      full_hint_loss = losses.hint_loss(
          truth=_mask_datapoint(truth_full, 1, t_axis=0),
          preds=full_preds,
          lengths=full_sample.features.lengths,
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4)


if __name__ == '__main__':
  absltest.main()