Pradeep Kumar committed on
Commit
7496baf
·
verified ·
1 Parent(s): a173dbb

Delete classifier_data_lib_test.py

Browse files
Files changed (1) hide show
  1. classifier_data_lib_test.py +0 -95
classifier_data_lib_test.py DELETED
@@ -1,95 +0,0 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib."""
16
-
17
- import os
18
- import tempfile
19
-
20
- from absl.testing import parameterized
21
- import tensorflow as tf, tf_keras
22
- import tensorflow_datasets as tfds
23
-
24
- from official.nlp.data import classifier_data_lib
25
- from official.nlp.tools import tokenization
26
-
27
-
28
def decode_record(record, name_to_features):
  """Parses one serialized tf.Example into a dict of feature tensors."""
  parsed_features = tf.io.parse_single_example(record, name_to_features)
  return parsed_features
31
-
32
-
33
class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that each SuperGLUE processor can write and re-read TFRecords."""

  def setUp(self):
    super(BertClassifierLibTest, self).setUp()
    self.model_dir = self.get_temp_dir()
    # Task name -> processor class; each parameterized case selects one.
    self.processors = {
        "CB": classifier_data_lib.CBProcessor,
        "SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
        "BOOLQ": classifier_data_lib.BoolQProcessor,
        "WIC": classifier_data_lib.WiCProcessor,
    }

    # A tiny WordPiece vocab is sufficient: the test only needs the
    # tokenizer to run, not to tokenize real text well.
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      vocab_writer.write("".join([x + "\n" for x in vocab_tokens
                                 ]).encode("utf-8"))
      vocab_file = vocab_writer.name
    # Fix: delete=False leaked the temp vocab file after the test run;
    # remove it during teardown.
    self.addCleanup(os.remove, vocab_file)
    self.tokenizer = tokenization.FullTokenizer(vocab_file)

  @parameterized.parameters(
      {"task_type": "CB"},
      {"task_type": "BOOLQ"},
      {"task_type": "SUPERGLUE-RTE"},
      {"task_type": "WIC"},
  )
  def test_generate_dataset_from_tfds_processor(self, task_type):
    """Generates TFRecords from mocked TFDS data and parses them back."""
    with tfds.testing.mock_data(num_examples=5):
      output_path = os.path.join(self.model_dir, task_type)

      processor = self.processors[task_type]()

      classifier_data_lib.generate_tf_record_from_data_file(
          processor,
          None,  # NOTE(review): presumably the data dir, unused for
                 # TFDS-backed processors — confirm against the helper.
          self.tokenizer,
          train_data_output_path=output_path,
          eval_data_output_path=output_path,
          test_data_output_path=output_path)
      files = tf.io.gfile.glob(output_path)
      self.assertNotEmpty(files)

      train_dataset = tf.data.TFRecordDataset(output_path)
      seq_length = 128
      label_type = tf.int64
      name_to_features = {
          "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
          "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "label_ids": tf.io.FixedLenFeature([], label_type),
      }
      train_dataset = train_dataset.map(
          lambda record: decode_record(record, name_to_features))

      # If data is retrieved without error, then all requirements
      # including data type/shapes are met.
      _ = next(iter(train_dataset))
92
-
93
-
94
# Script entry point: delegate to TensorFlow's test runner.
if __name__ == "__main__":
  tf.test.main()