"""Tests for GIVT model.""" |
|
|
|
from absl.testing import parameterized |
|
from big_vision.models.proj.givt import givt |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
from absl.testing import absltest |
|
|
|
|
|
_BATCH_SIZE = 2
_OUT_DIM = 4
_SEQ_LEN = 16
_NUM_MIXTURES = 4
|
|
def _make_test_model(**overwrites): |
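  """Builds a small GIVT model for testing, applying keyword overwrites."""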
  config = dict(
      num_heads=2,
      num_decoder_layers=1,
      mlp_dim=64,
      emb_dim=16,
      seq_len=_SEQ_LEN,
      out_dim=_OUT_DIM,
      num_mixtures=_NUM_MIXTURES,
  )
  config.update(overwrites)
  return givt.Model(**config)
|
|
class MaskedTransformerTest(parameterized.TestCase):

  @parameterized.product(rng_seed=[0])
  def test_masks(self, rng_seed):
    m = _make_test_model(style="masked")
    mask = m.get_input_mask_training(jax.random.PRNGKey(rng_seed), (2, 16))
    self.assertEqual(mask.shape, (2, 16))
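    # The training mask should select more than one position per example.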
    self.assertTrue(np.all(mask.sum(-1) > 1))
|
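  # Covers both masked and autoregressive ("ar") styles, multivariate and
  # per-channel mixture heads, and the optional label-dropping path.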
  @parameterized.product(
      train=[True, False],
      multivariate=[True, False],
      per_channel_mixtures=[True, False],
      drop_labels_probability=[0.0, 0.1],
      style=["masked", "ar"],
  )
  def test_apply(
      self,
      train,
      multivariate,
      per_channel_mixtures,
      drop_labels_probability,
      style,
  ):
    if per_channel_mixtures and multivariate:
      self.skipTest("Not supported")
    model = _make_test_model(
        style=style,
        multivariate=multivariate,
        num_mixtures=1 if multivariate else _NUM_MIXTURES,
        per_channel_mixtures=per_channel_mixtures,
        drop_labels_probability=drop_labels_probability,
    )
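    # Random inputs: a continuous sequence, integer class labels in [0, 10),
    # and a boolean mask over sequence positions.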
    sequence = jax.random.uniform(
        jax.random.PRNGKey(0), (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM)
    )
    labels = jax.random.uniform(
        jax.random.PRNGKey(0), (_BATCH_SIZE,), maxval=10
    ).astype(jnp.int32)
    input_mask = jax.random.uniform(
        jax.random.PRNGKey(0), (_BATCH_SIZE, _SEQ_LEN)
    ).astype(jnp.bool_)
    variables = model.init(
        jax.random.PRNGKey(0),
        sequence,
        labels,
        input_mask=input_mask,
        train=train,
    )
    logits, pdf = model.apply(
        variables, sequence, labels, input_mask=input_mask, train=train
    )
    nll = -pdf.log_prob(sequence)
    self.assertFalse(np.any(np.isnan(nll)))
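    # Expected parameters per token: multivariate -> _OUT_DIM means plus an
    # _OUT_DIM x _OUT_DIM scale matrix; per-channel mixtures -> 3 parameters
    # (weight, mean, scale) per mixture and channel; otherwise -> shared
    # mixture logits plus a mean and scale per mixture and channel.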
    if multivariate:
      self.assertEqual(
          logits.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM**2 + _OUT_DIM)
      )
      self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN))
    elif per_channel_mixtures:
      self.assertEqual(
          logits.shape,
          (_BATCH_SIZE, _SEQ_LEN, 3 * _NUM_MIXTURES * _OUT_DIM),
      )
      self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM))
    else:
      self.assertEqual(
          logits.shape,
          (_BATCH_SIZE, _SEQ_LEN, _NUM_MIXTURES + _NUM_MIXTURES * _OUT_DIM * 2),
      )
      self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN))
|
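    # Samples drawn from the predicted distribution should have the same
    # shape as the input sequence.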
    sample = pdf.sample(seed=jax.random.PRNGKey(0))
    self.assertEqual(sample.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM))
|
|
if __name__ == "__main__":
  absltest.main()
|
|