"""Tests for bert_ops.""" |
|
|
|
import tempfile |
|
|
|
from big_vision import input_pipeline |
|
import big_vision.pp.builder as pp_builder |
|
import big_vision.pp.ops_general |
|
from big_vision.pp.proj.flaxformer import bert_ops |
|
import tensorflow as tf |
|
|
|
|
|
|
|
# Tiny WordPiece vocabulary; a token's id is its line index, so
# [PAD]=0, [UNK]=1, "more"=2, "than"=3, "one"=4, [CLS]=5, [SEP]=6.
_BERT_VOCAB = [
    "[PAD]",
    "[UNK]",
    "more",
    "than",
    "one",
    "[CLS]",
    "[SEP]",
]


def _create_ds(pp_str, tensor_slices, num_examples):
  """Builds an inference dataset that yields all examples in a single batch."""
  return input_pipeline.make_for_inference(
      tf.data.Dataset.from_tensor_slices(tensor_slices),
      num_ex_per_process=[num_examples],
      preprocess_fn=pp_builder.get_preprocess_fn(pp_str),
      batch_size=num_examples,
  )[0]
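
# A hedged sketch, not part of the original test: for quick pipeline-free
# debugging, the preprocess fn built by pp_builder can typically be applied
# eagerly to a features dict. This assumes get_preprocess_fn returns a plain
# callable over dicts of tensors, and that a vocab file exists at the
# (hypothetical) path used below:
#
#   pp_fn = pp_builder.get_preprocess_fn(
#       "bert_tokenize(inkey='texts', vocab_path='/tmp/vocab.txt', max_len=5)")
#   pp_fn({"texts": tf.constant(["more than one"])})["labels"]
#   # -> [5, 2, 3, 4, 0] given _BERT_VOCAB above.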


class BertOpsTest(tf.test.TestCase):

  def test_tokenize(self):
    inkey = "texts"
    vocab_path = f"{tempfile.mkdtemp()}/vocab.txt"
    with open(vocab_path, "w") as f:
      f.write("\n".join(_BERT_VOCAB))
    pp_str = (
        f"bert_tokenize(inkey='{inkey}', vocab_path='{vocab_path}', max_len=5)"
        "|keep('labels')"
    )
    tensor_slices = {
        inkey: tf.ragged.constant([["one more"], ["more than one"], [""]])
    }
    ds = _create_ds(pp_str, tensor_slices, 3)
    # Each row is [CLS]=5 followed by the token ids, padded with [PAD]=0 to
    # max_len=5; e.g. "one more" -> [5, 4, 2, 0, 0].
    self.assertAllEqual(
        next(iter(ds))["labels"],
        [[5, 4, 2, 0, 0], [5, 2, 3, 4, 0], [5, 0, 0, 0, 0]],
    )


if __name__ == "__main__":
  tf.test.main()