File size: 4,502 Bytes
74e8f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
from big_vision.models.proj.givt import givt
from big_vision.models.proj.givt import parallel_decode
import chex
import jax
import jax.numpy as jnp
from absl.testing import absltest
_BATCH_SIZE = 2
_OUT_DIM = 4
_SEQ_LEN = 6
_NUM_MIXTURES = 4
def _make_test_model(**overwrites):
  """Constructs a small masked-style `givt.Model` for tests.

  Args:
    **overwrites: Keyword arguments that replace (or add to) the default
      model configuration below.

  Returns:
    A `givt.Model` built from the defaults merged with `overwrites`.
  """
  defaults = {
      "num_heads": 2,
      "num_decoder_layers": 1,
      "mlp_dim": 64,
      "emb_dim": 16,
      "seq_len": _SEQ_LEN,
      "out_dim": _OUT_DIM,
      "num_mixtures": _NUM_MIXTURES,
      "style": "masked",
  }
  # Later keys win, so caller-provided values override the defaults.
  return givt.Model(**{**defaults, **overwrites})
def _mask(*flags):
return jnp.asarray(flags).astype(jnp.bool_)
class HelperTest(absltest.TestCase):
  """Tests for private helper functions in `parallel_decode`."""

  def test_get_first_n(self):
    """`_get_bottom_k_mask` marks the k smallest values, ties broken by index."""
    with self.subTest("ordered"):
      values = jnp.asarray([4, 3, 2, 1, 0])
      k = jnp.asarray([3], jnp.int32)
      chex.assert_trees_all_equal(
          parallel_decode._get_bottom_k_mask(values, k), _mask(0, 0, 1, 1, 1)
      )
    with self.subTest("equal_values"):
      # All values tie: the first k positions are expected to be selected.
      values = jnp.ones((5,))
      k = jnp.asarray([3], jnp.int32)
      chex.assert_trees_all_equal(
          parallel_decode._get_bottom_k_mask(values, k), _mask(1, 1, 1, 0, 0)
      )
    with self.subTest("partial_ties"):
      # The k-th smallest value (2) is tied; earlier indices are expected to win.
      values = jnp.asarray([1, 2, 2, 2, 3])
      k = jnp.asarray([3], jnp.int32)
      chex.assert_trees_all_equal(
          parallel_decode._get_bottom_k_mask(values, k), _mask(1, 1, 1, 0, 0)
      )
class ParallelDecodeTest(parameterized.TestCase):
  """Tests `parallel_decode.decode_masked` on a tiny GIVT model."""

  # Number of decoding steps used by every generation config in this class.
  _NUM_STEPS = 4

  def _make_model(self, **overwrites):
    """Builds a small test model and initializes its variables.

    Args:
      **overwrites: Keyword arguments forwarded to `_make_test_model`.

    Returns:
      A `(model, variables)` tuple, where `variables` is the result of
      `model.init` on dummy inputs.
    """
    model = _make_test_model(**overwrites)
    # Use an independent key for each dummy-input draw instead of reusing one
    # key for all of them. The init key below stays PRNGKey(0), as before.
    seq_rng, label_rng, mask_rng = jax.random.split(jax.random.PRNGKey(1), 3)
    sequence = jax.random.uniform(seq_rng, (_BATCH_SIZE, _SEQ_LEN, _OUT_DIM))
    labels = jax.random.uniform(
        label_rng, (_BATCH_SIZE,), maxval=10
    ).astype(jnp.int32)
    # NOTE(review): casting uniform samples in [0, 1) to bool gives True except
    # for exact zeros, so this mask is effectively all-True. Confirm whether a
    # random mask (e.g. `< 0.5`) was intended.
    input_mask = jax.random.uniform(
        mask_rng, (_BATCH_SIZE, _SEQ_LEN)
    ).astype(jnp.bool_)
    variables = model.init(
        jax.random.PRNGKey(0),
        sequence,
        labels,
        input_mask=input_mask,
        train=False,
    )
    return model, variables

  def _test_model(self, rng, model, variables, config):
    """Runs masked decoding and checks its invariants.

    Checks that decoding reports `config.num_steps` steps and that every
    sequence position is uncovered exactly once across all steps.

    Args:
      rng: PRNG key passed to `decode_masked`.
      model: Model returned by `_make_model`.
      variables: Variables returned by `_make_model`.
      config: A `parallel_decode.MaskedGenerationConfig`.
    """
    labels = jnp.ones((_BATCH_SIZE,), dtype=jnp.int32)
    state = parallel_decode.decode_masked(
        rng,
        seq_len=_SEQ_LEN,
        feature_dim=_OUT_DIM,
        labels=labels,
        model=model,
        variables=variables,
        config=config,
    )
    # Compare against the config instead of a hard-coded step count.
    self.assertEqual(int(state.step), config.num_steps)
    # Each point is uncovered exactly once: summing over the leading axis
    # must give 1 for every (example, position).
    chex.assert_trees_all_equal(
        state.uncovered_per_step.sum(0),
        jnp.ones((_BATCH_SIZE, _SEQ_LEN), dtype=jnp.int32),
    )

  @parameterized.product(
      rng_seed=[1, 2],
      choice_temperature=[1.0, 4.0],
      multivariate=[True, False],
  )
  def test_decode_masked(self, rng_seed, choice_temperature, multivariate):
    """Masked decoding works for mixture and multivariate output heads."""
    rng = jax.random.PRNGKey(rng_seed)
    model, variables = self._make_model(
        num_mixtures=1 if multivariate else _NUM_MIXTURES,
        multivariate=multivariate,
    )
    config = parallel_decode.MaskedGenerationConfig(
        num_steps=self._NUM_STEPS,
        choice_temperature=choice_temperature,
    )
    self._test_model(rng, model, variables, config)

  @parameterized.product(
      rng_seed=[1, 2],
      choice_temperature=[1.0, 4.0],
      w=[0.0, 1.0, 3.0],
      per_channel_mixtures=[True, False],
  )
  def test_cfg(self, rng_seed, choice_temperature, w, per_channel_mixtures):
    """Masked decoding works with classifier-free guidance weight `w`."""
    rng = jax.random.PRNGKey(rng_seed)
    model, variables = self._make_model(
        num_mixtures=1 if per_channel_mixtures else 3,
        drop_labels_probability=0.1,
        per_channel_mixtures=per_channel_mixtures,
    )
    config = parallel_decode.MaskedGenerationConfig(
        num_steps=self._NUM_STEPS,
        choice_temperature=choice_temperature,
        cfg_inference_weight=w,
    )
    self._test_model(rng, model, variables, config)
if __name__ == "__main__":
  # `googletest` is never imported in this file; `absltest` is, and its
  # `main()` discovers and runs all TestCase classes (including parameterized).
  absltest.main()
|