Spaces:
Running
Running
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib.""" | |
import os | |
import tempfile | |
from absl.testing import parameterized | |
import tensorflow as tf, tf_keras | |
import tensorflow_datasets as tfds | |
from official.nlp.data import classifier_data_lib | |
from official.nlp.tools import tokenization | |
def decode_record(record, name_to_features):
  """Parses a single serialized `tf.train.Example` with a feature spec.

  Args:
    record: A scalar string tensor holding one serialized example.
    name_to_features: Mapping from feature name to its parsing config
      (e.g. `tf.io.FixedLenFeature`).

  Returns:
    A dict mapping each feature name to its parsed tensor.
  """
  parsed_features = tf.io.parse_single_example(
      serialized=record, features=name_to_features)
  return parsed_features
class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
  """Smoke tests for TFDS-backed processors in `classifier_data_lib`.

  Each parameterized case writes TFRecords from mocked TFDS data and checks
  that the first record parses with the expected feature spec.
  """

  def setUp(self):
    """Builds the processor registry and a tokenizer over a tiny vocab."""
    super().setUp()
    self.model_dir = self.get_temp_dir()
    # Maps the `task_type` test parameter to the processor class under test.
    self.processors = {
        "CB": classifier_data_lib.CBProcessor,
        "SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
        "BOOLQ": classifier_data_lib.BoolQProcessor,
        "WIC": classifier_data_lib.WiCProcessor,
    }

    # Tiny vocabulary, written one token per line to a temp file that
    # FullTokenizer loads by path.
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      vocab_writer.write("".join([x + "\n" for x in vocab_tokens
                                 ]).encode("utf-8"))
      vocab_file = vocab_writer.name
    # delete=False keeps the file on disk after the `with` closes it so the
    # tokenizer can reopen it by name; remove it at teardown so temp files
    # do not leak across test runs.
    self.addCleanup(os.remove, vocab_file)
    self.tokenizer = tokenization.FullTokenizer(vocab_file)

  # Without this decorator the runner calls the method with only `self`, so
  # it fails with a TypeError for the missing `task_type` argument.
  @parameterized.parameters(
      {"task_type": "CB"},
      {"task_type": "BOOLQ"},
      {"task_type": "SUPERGLUE-RTE"},
      {"task_type": "WIC"},
  )
  def test_generate_dataset_from_tfds_processor(self, task_type):
    """Generates TFRecords for `task_type` and parses the first record.

    Args:
      task_type: Key into `self.processors` selecting the processor to test.
    """
    with tfds.testing.mock_data(num_examples=5):
      output_path = os.path.join(self.model_dir, task_type)
      processor = self.processors[task_type]()

      # NOTE(review): train, eval and test outputs all target the same path,
      # so the file read back below holds whichever split is written last;
      # confirm this is intended.
      classifier_data_lib.generate_tf_record_from_data_file(
          processor,
          None,  # presumably data_dir; TFDS processors may not need it -- confirm
          self.tokenizer,
          train_data_output_path=output_path,
          eval_data_output_path=output_path,
          test_data_output_path=output_path)
      files = tf.io.gfile.glob(output_path)
      self.assertNotEmpty(files)

      train_dataset = tf.data.TFRecordDataset(output_path)
      # Assumed to match the library's default max sequence length -- confirm.
      seq_length = 128
      label_type = tf.int64
      name_to_features = {
          "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
          "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "label_ids": tf.io.FixedLenFeature([], label_type),
      }
      train_dataset = train_dataset.map(
          lambda record: decode_record(record, name_to_features))

      # If data is retrieved without error, then all requirements
      # including data type/shapes are met.
      _ = next(iter(train_dataset))
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point.
  tf.test.main()