"""Tests for parallel_decode."""

from absl.testing import absltest
from absl.testing import parameterized
from big_vision.models.proj.givt import givt
from big_vision.models.proj.givt import parallel_decode
import chex
import jax
import jax.numpy as jnp


_BATCH_SIZE = 2
_OUT_DIM = 4
_SEQ_LEN = 6
_NUM_MIXTURES = 4


def _make_test_model(**overwrites):
  config = dict(
      num_heads=2,
      num_decoder_layers=1,
      mlp_dim=64,
      emb_dim=16,
      seq_len=_SEQ_LEN,
      out_dim=_OUT_DIM,
      num_mixtures=_NUM_MIXTURES,
      style="masked",
  )
  config.update(overwrites)
  return givt.Model(**config)


def _mask(*flags):
  return jnp.asarray(flags).astype(jnp.bool_)


class HelperTest(absltest.TestCase):

  def test_get_bottom_k_mask(self):
    with self.subTest("ordered"):
      values = jnp.asarray([4, 3, 2, 1, 0])
      k = jnp.asarray([3], jnp.int32)
      chex.assert_trees_all_equal(
          parallel_decode._get_bottom_k_mask(values, k), _mask(0, 0, 1, 1, 1)
      )

    with self.subTest("equal_values"):
      values = jnp.ones((5,))
      k = jnp.asarray([3], jnp.int32)
      chex.assert_trees_all_equal(
          parallel_decode._get_bottom_k_mask(values, k), _mask(1, 1, 1, 0, 0)
      )

    with self.subTest("partial_ties"):
      values = jnp.asarray([1, 2, 2, 2, 3])
      k = jnp.asarray([3], jnp.int32)
      chex.assert_trees_all_equal(
          parallel_decode._get_bottom_k_mask(values, k), _mask(1, 1, 1, 0, 0)
      )


class ParallelDecodeTest(parameterized.TestCase):

  def _make_model(self, **overwrites):
    model = _make_test_model(**overwrites)
    sequence = jax.random.uniform(
        jax.random.PRNGKey(0), (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM)
    )
    labels = jax.random.uniform(
        jax.random.PRNGKey(0), (_BATCH_SIZE,), maxval=10
    ).astype(jnp.int32)
    input_mask = jax.random.uniform(
        jax.random.PRNGKey(0), (_BATCH_SIZE, _SEQ_LEN)
    ).astype(jnp.bool_)
    variables = model.init(
        jax.random.PRNGKey(0),
        sequence,
        labels,
        input_mask=input_mask,
        train=False,
    )
    return model, variables

  def _test_model(self, rng, model, variables, config):
    labels = jnp.ones((_BATCH_SIZE,), dtype=jnp.int32)
    state = parallel_decode.decode_masked(
        rng,
        seq_len=_SEQ_LEN,
        feature_dim=_OUT_DIM,
        labels=labels,
        model=model,
        variables=variables,
        config=config,
    )
    self.assertEqual(int(state.step), 4)

    # Each sequence position should be uncovered exactly once across all steps.
    chex.assert_trees_all_equal(
        state.uncovered_per_step.sum(0),
        jnp.ones((_BATCH_SIZE, _SEQ_LEN), dtype=jnp.int32),
    )

  @parameterized.product(
      rng_seed=[1, 2],
      choice_temperature=[1.0, 4.0],
      multivariate=[True, False],
  )
  def test_decode_masked(self, rng_seed, choice_temperature, multivariate):
    rng = jax.random.PRNGKey(rng_seed)
    model, variables = self._make_model(
        num_mixtures=1 if multivariate else _NUM_MIXTURES,
        multivariate=multivariate,
    )
    config = parallel_decode.MaskedGenerationConfig(
        num_steps=4,
        choice_temperature=choice_temperature,
    )
    self._test_model(rng, model, variables, config)

  @parameterized.product(
      rng_seed=[1, 2],
      choice_temperature=[1.0, 4.0],
      w=[0.0, 1.0, 3.0],
      per_channel_mixtures=[True, False],
  )
  def test_cfg(self, rng_seed, choice_temperature, w, per_channel_mixtures):
    rng = jax.random.PRNGKey(rng_seed)
    model, variables = self._make_model(
        num_mixtures=1 if per_channel_mixtures else 3,
        drop_labels_probability=0.1,
        per_channel_mixtures=per_channel_mixtures,
    )
    config = parallel_decode.MaskedGenerationConfig(
        num_steps=4,
        choice_temperature=choice_temperature,
        cfg_inference_weight=w,
    )
    self._test_model(rng, model, variables, config)


if __name__ == "__main__":
  absltest.main()