pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GIVT model."""
from absl.testing import parameterized
from big_vision.models.proj.givt import givt
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest
# Tiny sizes so every parameterized case builds and runs quickly.
_BATCH_SIZE = 2
_OUT_DIM = 4
_SEQ_LEN = 16
_NUM_MIXTURES = 4


def _make_test_model(**overwrites):
  """Builds a small `givt.Model` for tests.

  Args:
    **overwrites: Model keyword arguments that replace (or extend) the
      default test configuration below.

  Returns:
    A `givt.Model` constructed from the merged configuration.
  """
  defaults = {
      "num_heads": 2,
      "num_decoder_layers": 1,
      "mlp_dim": 64,
      "emb_dim": 16,
      "seq_len": _SEQ_LEN,
      "out_dim": _OUT_DIM,
      "num_mixtures": _NUM_MIXTURES,
  }
  # Later entries win, so caller overrides take precedence over defaults.
  return givt.Model(**{**defaults, **overwrites})
class MaskedTransformerTest(parameterized.TestCase):
  """Smoke tests for `givt.Model`: training-mask sampling and forward pass."""

  @parameterized.product(rng_seed=[0])
  def test_masks(self, rng_seed):
    """Checks the shape and density of the training-time input mask."""
    m = _make_test_model(style="masked")
    # Use the same sizes the model was configured with.
    mask = m.get_input_mask_training(
        jax.random.PRNGKey(rng_seed), (_BATCH_SIZE, _SEQ_LEN)
    )
    self.assertEqual(mask.shape, (_BATCH_SIZE, _SEQ_LEN))
    # Every example should have more than one position masked out.
    self.assertTrue(np.all(mask.sum(-1) > 1))

  @parameterized.product(
      train=[True, False],
      multivariate=[True, False],
      per_channel_mixtures=[True, False],
      drop_labels_probability=[0.0, 0.1],
      style=["masked", "ar"],
  )
  def test_apply(
      self,
      train,
      multivariate,
      per_channel_mixtures,
      drop_labels_probability,
      style,
  ):
    """Runs `init`/`apply` and checks logits, NLL and sample shapes."""
    if per_channel_mixtures and multivariate:
      self.skipTest("Not supported")
    model = _make_test_model(
        style=style,
        multivariate=multivariate,
        # The multivariate head is exercised with a single mixture component.
        num_mixtures=1 if multivariate else _NUM_MIXTURES,
        per_channel_mixtures=per_channel_mixtures,
        drop_labels_probability=drop_labels_probability,
    )
    # Independent subkeys so the inputs are not derived from the same bits.
    seq_key, label_key, mask_key, init_key = jax.random.split(
        jax.random.PRNGKey(0), 4
    )
    sequence = jax.random.uniform(seq_key, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM))
    labels = jax.random.randint(
        label_key, (_BATCH_SIZE,), minval=0, maxval=10, dtype=jnp.int32
    )
    # A genuinely mixed boolean mask. Casting uniform floats to bool would
    # yield True almost everywhere (only exact zeros become False).
    input_mask = jax.random.bernoulli(mask_key, 0.5, (_BATCH_SIZE, _SEQ_LEN))
    variables = model.init(
        init_key,
        sequence,
        labels,
        input_mask=input_mask,
        train=train,
    )
    logits, pdf = model.apply(
        variables, sequence, labels, input_mask=input_mask, train=train
    )
    nll = -pdf.log_prob(sequence)
    self.assertFalse(np.any(np.isnan(nll)))
    if multivariate:
      self.assertEqual(
          logits.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM**2 + _OUT_DIM)
      )
      self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN))
    elif per_channel_mixtures:
      self.assertEqual(
          logits.shape,
          (_BATCH_SIZE, _SEQ_LEN, 3 * _NUM_MIXTURES * _OUT_DIM),
      )
      # Per-channel mixtures keep a separate NLL for every output channel.
      self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM))
    else:
      self.assertEqual(
          logits.shape,
          (_BATCH_SIZE, _SEQ_LEN, _NUM_MIXTURES + _NUM_MIXTURES * _OUT_DIM * 2),
      )
      self.assertEqual(nll.shape, (_BATCH_SIZE, _SEQ_LEN))
    sample = pdf.sample(seed=jax.random.PRNGKey(0))
    self.assertEqual(sample.shape, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM))
if __name__ == "__main__":
  # `absltest` is the runner imported at the top of the file; `googletest`
  # is never imported here and would raise NameError.
  absltest.main()