|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for bert.""" |
|
|
|
import tempfile |
|
|
|
from big_vision import input_pipeline |
|
from big_vision.models.proj.flaxformer import bert |
|
from big_vision.models.proj.flaxformer import bert_test_util |
|
import big_vision.pp.builder as pp_builder |
|
import big_vision.pp.ops_general |
|
import big_vision.pp.proj.flaxformer.bert_ops |
|
import flax |
|
import jax |
|
import jax.numpy as jnp |
|
import tensorflow as tf |
|
|
|
|
|
|
|
_BERT_VOCAB = [ |
|
"[PAD]", |
|
"[UNK]", |
|
"this", |
|
"is", |
|
"a", |
|
"test", |
|
"[CLS]", |
|
"[SEP]", |
|
] |
|
_TOKEN_LEN = 16 |
|
|
|
|
|
class BertTest(tf.test.TestCase):
  """Smoke test: load checkpoint parameters into the BERT model and apply it."""

  def test_load_apply(self):
    """Tokenizes one sentence, loads params and checks the output shapes."""
    inkey = "text"

    # Keep the vocab file in a temporary directory that is removed when the
    # test finishes; a bare `tempfile.mkdtemp()` would leave it behind.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    vocab_path = f"{tmp_dir.name}/vocab.txt"
    with open(vocab_path, "w", encoding="utf-8") as f:
      f.write("\n".join(_BERT_VOCAB))

    # Preprocessing spec: tokenize the `inkey` feature, then keep only the
    # resulting 'labels' entry.
    pp_spec = (
        f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', "
        f"max_len={_TOKEN_LEN})"
        "|keep('labels')"
    )
    ds2, _ = input_pipeline.make_for_inference(
        tf.data.Dataset.from_tensor_slices(
            {inkey: tf.ragged.constant([["this is a test"]])}),
        num_ex_per_process=[1],
        preprocess_fn=pp_builder.get_preprocess_fn(pp_spec),
        batch_size=1,
    )
    text = jnp.array(next(iter(ds2))["labels"])

    model = bert.Model(config="base")
    variables = model.init(jax.random.PRNGKey(0), text)
    # Replace the freshly initialized params with the ones from the test
    # checkpoint before applying the model.
    params = bert.load(flax.core.unfreeze(variables)["params"],
                       bert_test_util.create_base_checkpoint())
    x, out = model.apply({"params": params}, text)

    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from newer JAX releases.
    self.assertAllEqual(jax.tree_util.tree_map(jnp.shape, x), (1, 768))
    self.assertAllEqual(
        jax.tree_util.tree_map(jnp.shape, out), {
            # _TOKEN_LEN is the `max_len` handed to the tokenizer above.
            "transformed": (1, _TOKEN_LEN, 768),
            "pre_logits": (1, 768),
        })
|
|
|
|
|
if __name__ == "__main__":
  # Executed as a script: run the test cases through TensorFlow's test runner.
  tf.test.main()
|
|