Pradeep Kumar commited on
Commit
82a3db7
·
verified ·
1 Parent(s): 31f84b4

Delete sentence_prediction_dataloader_test.py

Browse files
sentence_prediction_dataloader_test.py DELETED
@@ -1,290 +0,0 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Tests for official.nlp.data.sentence_prediction_dataloader."""
16
- import os
17
-
18
- from absl.testing import parameterized
19
- import numpy as np
20
- import tensorflow as tf, tf_keras
21
-
22
- from sentencepiece import SentencePieceTrainer
23
- from official.nlp.data import sentence_prediction_dataloader as loader
24
-
25
-
26
- def _create_fake_preprocessed_dataset(output_path, seq_length, label_type):
27
- """Creates a fake dataset."""
28
- writer = tf.io.TFRecordWriter(output_path)
29
-
30
- def create_int_feature(values):
31
- f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
32
- return f
33
-
34
- def create_float_feature(values):
35
- f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
36
- return f
37
-
38
- for _ in range(100):
39
- features = {}
40
- input_ids = np.random.randint(100, size=(seq_length))
41
- features['input_ids'] = create_int_feature(input_ids)
42
- features['input_mask'] = create_int_feature(np.ones_like(input_ids))
43
- features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
44
-
45
- if label_type == 'int':
46
- features['label_ids'] = create_int_feature([1])
47
- elif label_type == 'float':
48
- features['label_ids'] = create_float_feature([0.5])
49
- else:
50
- raise ValueError('Unsupported label_type: %s' % label_type)
51
-
52
- tf_example = tf.train.Example(features=tf.train.Features(feature=features))
53
- writer.write(tf_example.SerializeToString())
54
- writer.close()
55
-
56
-
57
- def _create_fake_raw_dataset(output_path, text_fields, label_type):
58
- """Creates a fake tf record file."""
59
- writer = tf.io.TFRecordWriter(output_path)
60
-
61
- def create_str_feature(value):
62
- f = tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
63
- return f
64
-
65
- def create_int_feature(values):
66
- f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
67
- return f
68
-
69
- def create_float_feature(values):
70
- f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
71
- return f
72
-
73
- for _ in range(100):
74
- features = {}
75
- for text_field in text_fields:
76
- features[text_field] = create_str_feature([b'hello world'])
77
-
78
- if label_type == 'int':
79
- features['label'] = create_int_feature([0])
80
- elif label_type == 'float':
81
- features['label'] = create_float_feature([0.5])
82
- else:
83
- raise ValueError('Unexpected label_type: %s' % label_type)
84
- tf_example = tf.train.Example(features=tf.train.Features(feature=features))
85
- writer.write(tf_example.SerializeToString())
86
- writer.close()
87
-
88
-
89
- def _create_fake_sentencepiece_model(output_dir):
90
- vocab = ['a', 'b', 'c', 'd', 'e', 'abc', 'def', 'ABC', 'DEF']
91
- model_prefix = os.path.join(output_dir, 'spm_model')
92
- input_text_file_path = os.path.join(output_dir, 'train_input.txt')
93
- with tf.io.gfile.GFile(input_text_file_path, 'w') as f:
94
- f.write(' '.join(vocab + ['\n']))
95
- # Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
96
- full_vocab_size = len(vocab) + 7
97
- flags = dict(
98
- model_prefix=model_prefix,
99
- model_type='word',
100
- input=input_text_file_path,
101
- pad_id=0,
102
- unk_id=1,
103
- control_symbols='[CLS],[SEP],[MASK]',
104
- vocab_size=full_vocab_size,
105
- bos_id=full_vocab_size - 2,
106
- eos_id=full_vocab_size - 1)
107
- SentencePieceTrainer.Train(' '.join(
108
- ['--{}={}'.format(k, v) for k, v in flags.items()]))
109
- return model_prefix + '.model'
110
-
111
-
112
- def _create_fake_vocab_file(vocab_file_path):
113
- tokens = ['[PAD]']
114
- for i in range(1, 100):
115
- tokens.append('[unused%d]' % i)
116
- tokens.extend(['[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world'])
117
- with tf.io.gfile.GFile(vocab_file_path, 'w') as outfile:
118
- outfile.write('\n'.join(tokens))
119
-
120
-
121
- class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
122
-
123
- @parameterized.parameters(('int', tf.int32), ('float', tf.float32))
124
- def test_load_dataset(self, label_type, expected_label_type):
125
- input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
126
- batch_size = 10
127
- seq_length = 128
128
- _create_fake_preprocessed_dataset(input_path, seq_length, label_type)
129
- data_config = loader.SentencePredictionDataConfig(
130
- input_path=input_path,
131
- seq_length=seq_length,
132
- global_batch_size=batch_size,
133
- label_type=label_type)
134
- dataset = loader.SentencePredictionDataLoader(data_config).load()
135
- features = next(iter(dataset))
136
- self.assertCountEqual(
137
- ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
138
- features.keys())
139
- self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
140
- self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
141
- self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
142
- self.assertEqual(features['label_ids'].shape, (batch_size,))
143
- self.assertEqual(features['label_ids'].dtype, expected_label_type)
144
-
145
- def test_load_dataset_with_label_mapping(self):
146
- input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
147
- batch_size = 10
148
- seq_length = 128
149
- _create_fake_preprocessed_dataset(input_path, seq_length, 'int')
150
- data_config = loader.SentencePredictionDataConfig(
151
- input_path=input_path,
152
- seq_length=seq_length,
153
- global_batch_size=batch_size,
154
- label_type='int',
155
- label_name=('label_ids', 'next_sentence_labels'))
156
- dataset = loader.SentencePredictionDataLoader(data_config).load()
157
- features = next(iter(dataset))
158
- self.assertCountEqual([
159
- 'input_word_ids', 'input_mask', 'input_type_ids',
160
- 'next_sentence_labels', 'label_ids'
161
- ], features.keys())
162
- self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
163
- self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
164
- self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
165
- self.assertEqual(features['label_ids'].shape, (batch_size,))
166
- self.assertEqual(features['label_ids'].dtype, tf.int32)
167
- self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
168
- self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
169
-
170
-
171
- class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
172
- parameterized.TestCase):
173
-
174
- @parameterized.parameters(True, False)
175
- def test_python_wordpiece_preprocessing(self, use_tfds):
176
- batch_size = 10
177
- seq_length = 256 # Non-default value.
178
- lower_case = True
179
-
180
- tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
181
- text_fields = ['sentence1', 'sentence2']
182
- if not use_tfds:
183
- _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
184
-
185
- vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
186
- _create_fake_vocab_file(vocab_file_path)
187
-
188
- data_config = loader.SentencePredictionTextDataConfig(
189
- input_path='' if use_tfds else tf_record_path,
190
- tfds_name='glue/mrpc' if use_tfds else '',
191
- tfds_split='train' if use_tfds else '',
192
- text_fields=text_fields,
193
- global_batch_size=batch_size,
194
- seq_length=seq_length,
195
- is_training=True,
196
- lower_case=lower_case,
197
- vocab_file=vocab_file_path)
198
- dataset = loader.SentencePredictionTextDataLoader(data_config).load()
199
- features = next(iter(dataset))
200
- label_field = data_config.label_field
201
- expected_keys = [
202
- 'input_word_ids', 'input_type_ids', 'input_mask', label_field
203
- ]
204
- if use_tfds:
205
- expected_keys += ['idx']
206
- self.assertCountEqual(expected_keys, features.keys())
207
- self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
208
- self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
209
- self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
210
- self.assertEqual(features[label_field].shape, (batch_size,))
211
-
212
- @parameterized.parameters(True, False)
213
- def test_python_sentencepiece_preprocessing(self, use_tfds):
214
- batch_size = 10
215
- seq_length = 256 # Non-default value.
216
- lower_case = True
217
-
218
- tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
219
- text_fields = ['sentence1', 'sentence2']
220
- if not use_tfds:
221
- _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
222
-
223
- sp_model_file_path = _create_fake_sentencepiece_model(self.get_temp_dir())
224
- data_config = loader.SentencePredictionTextDataConfig(
225
- input_path='' if use_tfds else tf_record_path,
226
- tfds_name='glue/mrpc' if use_tfds else '',
227
- tfds_split='train' if use_tfds else '',
228
- text_fields=text_fields,
229
- global_batch_size=batch_size,
230
- seq_length=seq_length,
231
- is_training=True,
232
- lower_case=lower_case,
233
- tokenization='SentencePiece',
234
- vocab_file=sp_model_file_path,
235
- )
236
- dataset = loader.SentencePredictionTextDataLoader(data_config).load()
237
- features = next(iter(dataset))
238
- label_field = data_config.label_field
239
- expected_keys = [
240
- 'input_word_ids', 'input_type_ids', 'input_mask', label_field
241
- ]
242
- if use_tfds:
243
- expected_keys += ['idx']
244
- self.assertCountEqual(expected_keys, features.keys())
245
- self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
246
- self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
247
- self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
248
- self.assertEqual(features[label_field].shape, (batch_size,))
249
-
250
- @parameterized.parameters(True, False)
251
- def test_saved_model_preprocessing(self, use_tfds):
252
- batch_size = 10
253
- seq_length = 256 # Non-default value.
254
-
255
- tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
256
- text_fields = ['sentence1', 'sentence2']
257
- if not use_tfds:
258
- _create_fake_raw_dataset(tf_record_path, text_fields, label_type='float')
259
-
260
- vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
261
- _create_fake_vocab_file(vocab_file_path)
262
- data_config = loader.SentencePredictionTextDataConfig(
263
- input_path='' if use_tfds else tf_record_path,
264
- tfds_name='glue/mrpc' if use_tfds else '',
265
- tfds_split='train' if use_tfds else '',
266
- text_fields=text_fields,
267
- global_batch_size=batch_size,
268
- seq_length=seq_length,
269
- is_training=True,
270
- preprocessing_hub_module_url=(
271
- 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
272
- label_type='int' if use_tfds else 'float',
273
- )
274
- dataset = loader.SentencePredictionTextDataLoader(data_config).load()
275
- features = next(iter(dataset))
276
- label_field = data_config.label_field
277
- expected_keys = [
278
- 'input_word_ids', 'input_type_ids', 'input_mask', label_field
279
- ]
280
- if use_tfds:
281
- expected_keys += ['idx']
282
- self.assertCountEqual(expected_keys, features.keys())
283
- self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
284
- self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
285
- self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
286
- self.assertEqual(features[label_field].shape, (batch_size,))
287
-
288
-
289
- if __name__ == '__main__':
290
- tf.test.main()