|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for processors.py.""" |
|
|
|
from absl.testing import absltest |
|
import chex |
|
from clrs._src import processors |
|
import haiku as hk |
|
import jax.numpy as jnp |
|
|
|
|
|
class MemnetTest(absltest.TestCase):
  """Smoke tests for `processors.MemNetFull`."""

  def test_simple_run_and_check_shapes(self):
    """Runs `MemNetFull._apply` once and checks output shape and dtype."""
    batch_size = 64
    # Hyperparameters forwarded verbatim to the MemNetFull constructor.
    memnet_kwargs = dict(
        vocab_size=177,
        embedding_size=64,
        sentence_size=11,
        memory_size=320,
        linear_output_size=128,
        num_hops=2,
        use_ln=True,
    )
    vocab_size = memnet_kwargs['vocab_size']
    sentence_size = memnet_kwargs['sentence_size']
    memory_size = memnet_kwargs['memory_size']

    @hk.transform
    def forward(queries, stories):
      # Haiku modules must be constructed inside the transformed function.
      memnet = processors.MemNetFull(**memnet_kwargs)
      # The private `_apply` entry point is called directly with the
      # query/story arrays built below.
      return memnet._apply(queries, stories)  # pylint: disable=protected-access

    query_shape = (batch_size, sentence_size)
    story_shape = (batch_size, memory_size, sentence_size)
    queries = jnp.ones(query_shape, dtype=jnp.int32)
    stories = jnp.ones(story_shape, dtype=jnp.int32)

    # `rng_seq` is a sequence of PRNG keys; only its first key is consumed,
    # for parameter initialisation.
    rng_seq = hk.PRNGSequence(42)
    params = forward.init(next(rng_seq), queries, stories)
    # `apply` is given no RNG (None).
    output = forward.apply(params, None, queries, stories)

    chex.assert_shape(output, (batch_size, vocab_size))
    chex.assert_type(output, jnp.float32)
|
|
|
|
|
if __name__ == '__main__':
  # Entry point when this module is executed directly: hand control to
  # absltest, which runs the test cases defined above.
  absltest.main()
|
|