|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for the UViM ViT VQ-VAE model."""
|
from absl.testing import absltest |
|
|
|
from big_vision.models.proj.uvim import vit |
|
import jax |
|
import jax.numpy as jnp |
|
import ml_collections |
|
|
|
|
|
class ViTVQVAEModelTest(absltest.TestCase):
  """Smoke tests for `vit.Model` from the UViM project."""

  def test_model(self):
    """Builds a tiny model, runs init + apply, and checks output logits."""
    input_hw = (32, 32)
    model_config = ml_collections.ConfigDict(dict(
        input_size=input_hw,
        code_len=4,
        width=16,
        mlp_dim=64,
        num_heads=4,
        enc_depth=1,
        dec_depth=1,
        with_encoder_ctx=True,
        with_decoder_ctx=True,
        statistics_axis_name=None,
        inputs=dict(in1=(10, 3), in2=(25,)),
        outputs=dict(out1=(5,), out2=(20,)),
    ))
    model = vit.Model(**model_config)

    batch_size = 4
    # The token count is an 8x8-patch grid over the input image. This assumes
    # vit.Model's default patch size is 8 (not visible here) -- TODO confirm.
    patch = 8
    seq_len = (input_hw[0] // patch) * (input_hw[1] // patch)

    x = {
        "in1": jnp.zeros((batch_size, seq_len, 10, 3)),
        "in2": jnp.zeros((batch_size, seq_len, 25)),
    }
    # Context image: (batch, height, width, 3).
    ctx_image = jnp.zeros((batch_size, *model_config.input_size, 3))

    # `init` returns every variable collection, not only the parameters.
    variables = model.init(
        {"params": jax.random.PRNGKey(0), "state": jax.random.PRNGKey(1)},
        x, ctx=ctx_image)
    self.assertEqual(variables.keys(), {"params", "state"})

    # With `mutable=["state"]`, `apply` returns (outputs, mutated collections).
    # The mutated "state" collection is not inspected by this test.
    (logits, _), _ = model.apply(
        variables, x, ctx=ctx_image, train=True, update_dict=True,
        rngs={"dropout": jax.random.PRNGKey(0),
              "vqvae": jax.random.PRNGKey(0)},
        mutable=["state"])
    self.assertEqual(logits.keys(), {"out1", "out2"})
    for name, out_dim in (("out1", 5), ("out2", 20)):
      self.assertEqual(logits[name].shape, (batch_size, seq_len, out_dim))
|
|
|
|
|
if __name__ == "__main__":
  # Run the absltest runner when this file is executed directly.
  absltest.main()
|
|