Spaces:
Running
Running
Pradeep Kumar
commited on
Delete sentence_prediction_dataloader_test.py
Browse files
sentence_prediction_dataloader_test.py
DELETED
@@ -1,290 +0,0 @@
|
|
1 |
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
"""Tests for official.nlp.data.sentence_prediction_dataloader."""
|
16 |
-
import os
|
17 |
-
|
18 |
-
from absl.testing import parameterized
|
19 |
-
import numpy as np
|
20 |
-
import tensorflow as tf, tf_keras
|
21 |
-
|
22 |
-
from sentencepiece import SentencePieceTrainer
|
23 |
-
from official.nlp.data import sentence_prediction_dataloader as loader
|
24 |
-
|
25 |
-
|
26 |
-
def _create_fake_preprocessed_dataset(output_path, seq_length, label_type):
|
27 |
-
"""Creates a fake dataset."""
|
28 |
-
writer = tf.io.TFRecordWriter(output_path)
|
29 |
-
|
30 |
-
def create_int_feature(values):
|
31 |
-
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
32 |
-
return f
|
33 |
-
|
34 |
-
def create_float_feature(values):
|
35 |
-
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
36 |
-
return f
|
37 |
-
|
38 |
-
for _ in range(100):
|
39 |
-
features = {}
|
40 |
-
input_ids = np.random.randint(100, size=(seq_length))
|
41 |
-
features['input_ids'] = create_int_feature(input_ids)
|
42 |
-
features['input_mask'] = create_int_feature(np.ones_like(input_ids))
|
43 |
-
features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
|
44 |
-
|
45 |
-
if label_type == 'int':
|
46 |
-
features['label_ids'] = create_int_feature([1])
|
47 |
-
elif label_type == 'float':
|
48 |
-
features['label_ids'] = create_float_feature([0.5])
|
49 |
-
else:
|
50 |
-
raise ValueError('Unsupported label_type: %s' % label_type)
|
51 |
-
|
52 |
-
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
53 |
-
writer.write(tf_example.SerializeToString())
|
54 |
-
writer.close()
|
55 |
-
|
56 |
-
|
57 |
-
def _create_fake_raw_dataset(output_path, text_fields, label_type):
|
58 |
-
"""Creates a fake tf record file."""
|
59 |
-
writer = tf.io.TFRecordWriter(output_path)
|
60 |
-
|
61 |
-
def create_str_feature(value):
|
62 |
-
f = tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
|
63 |
-
return f
|
64 |
-
|
65 |
-
def create_int_feature(values):
|
66 |
-
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
67 |
-
return f
|
68 |
-
|
69 |
-
def create_float_feature(values):
|
70 |
-
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
71 |
-
return f
|
72 |
-
|
73 |
-
for _ in range(100):
|
74 |
-
features = {}
|
75 |
-
for text_field in text_fields:
|
76 |
-
features[text_field] = create_str_feature([b'hello world'])
|
77 |
-
|
78 |
-
if label_type == 'int':
|
79 |
-
features['label'] = create_int_feature([0])
|
80 |
-
elif label_type == 'float':
|
81 |
-
features['label'] = create_float_feature([0.5])
|
82 |
-
else:
|
83 |
-
raise ValueError('Unexpected label_type: %s' % label_type)
|
84 |
-
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
85 |
-
writer.write(tf_example.SerializeToString())
|
86 |
-
writer.close()
|
87 |
-
|
88 |
-
|
89 |
-
def _create_fake_sentencepiece_model(output_dir):
|
90 |
-
vocab = ['a', 'b', 'c', 'd', 'e', 'abc', 'def', 'ABC', 'DEF']
|
91 |
-
model_prefix = os.path.join(output_dir, 'spm_model')
|
92 |
-
input_text_file_path = os.path.join(output_dir, 'train_input.txt')
|
93 |
-
with tf.io.gfile.GFile(input_text_file_path, 'w') as f:
|
94 |
-
f.write(' '.join(vocab + ['\n']))
|
95 |
-
# Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
|
96 |
-
full_vocab_size = len(vocab) + 7
|
97 |
-
flags = dict(
|
98 |
-
model_prefix=model_prefix,
|
99 |
-
model_type='word',
|
100 |
-
input=input_text_file_path,
|
101 |
-
pad_id=0,
|
102 |
-
unk_id=1,
|
103 |
-
control_symbols='[CLS],[SEP],[MASK]',
|
104 |
-
vocab_size=full_vocab_size,
|
105 |
-
bos_id=full_vocab_size - 2,
|
106 |
-
eos_id=full_vocab_size - 1)
|
107 |
-
SentencePieceTrainer.Train(' '.join(
|
108 |
-
['--{}={}'.format(k, v) for k, v in flags.items()]))
|
109 |
-
return model_prefix + '.model'
|
110 |
-
|
111 |
-
|
112 |
-
def _create_fake_vocab_file(vocab_file_path):
|
113 |
-
tokens = ['[PAD]']
|
114 |
-
for i in range(1, 100):
|
115 |
-
tokens.append('[unused%d]' % i)
|
116 |
-
tokens.extend(['[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world'])
|
117 |
-
with tf.io.gfile.GFile(vocab_file_path, 'w') as outfile:
|
118 |
-
outfile.write('\n'.join(tokens))
|
119 |
-
|
120 |
-
|
121 |
-
class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
|
122 |
-
|
123 |
-
@parameterized.parameters(('int', tf.int32), ('float', tf.float32))
|
124 |
-
def test_load_dataset(self, label_type, expected_label_type):
|
125 |
-
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
126 |
-
batch_size = 10
|
127 |
-
seq_length = 128
|
128 |
-
_create_fake_preprocessed_dataset(input_path, seq_length, label_type)
|
129 |
-
data_config = loader.SentencePredictionDataConfig(
|
130 |
-
input_path=input_path,
|
131 |
-
seq_length=seq_length,
|
132 |
-
global_batch_size=batch_size,
|
133 |
-
label_type=label_type)
|
134 |
-
dataset = loader.SentencePredictionDataLoader(data_config).load()
|
135 |
-
features = next(iter(dataset))
|
136 |
-
self.assertCountEqual(
|
137 |
-
['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
|
138 |
-
features.keys())
|
139 |
-
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
140 |
-
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
141 |
-
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
142 |
-
self.assertEqual(features['label_ids'].shape, (batch_size,))
|
143 |
-
self.assertEqual(features['label_ids'].dtype, expected_label_type)
|
144 |
-
|
145 |
-
def test_load_dataset_with_label_mapping(self):
|
146 |
-
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
147 |
-
batch_size = 10
|
148 |
-
seq_length = 128
|
149 |
-
_create_fake_preprocessed_dataset(input_path, seq_length, 'int')
|
150 |
-
data_config = loader.SentencePredictionDataConfig(
|
151 |
-
input_path=input_path,
|
152 |
-
seq_length=seq_length,
|
153 |
-
global_batch_size=batch_size,
|
154 |
-
label_type='int',
|
155 |
-
label_name=('label_ids', 'next_sentence_labels'))
|
156 |
-
dataset = loader.SentencePredictionDataLoader(data_config).load()
|
157 |
-
features = next(iter(dataset))
|
158 |
-
self.assertCountEqual([
|
159 |
-
'input_word_ids', 'input_mask', 'input_type_ids',
|
160 |
-
'next_sentence_labels', 'label_ids'
|
161 |
-
], features.keys())
|
162 |
-
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
163 |
-
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
164 |
-
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
165 |
-
self.assertEqual(features['label_ids'].shape, (batch_size,))
|
166 |
-
self.assertEqual(features['label_ids'].dtype, tf.int32)
|
167 |
-
self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
|
168 |
-
self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
|
169 |
-
|
170 |
-
|
171 |
-
class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
|
172 |
-
parameterized.TestCase):
|
173 |
-
|
174 |
-
@parameterized.parameters(True, False)
|
175 |
-
def test_python_wordpiece_preprocessing(self, use_tfds):
|
176 |
-
batch_size = 10
|
177 |
-
seq_length = 256 # Non-default value.
|
178 |
-
lower_case = True
|
179 |
-
|
180 |
-
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
181 |
-
text_fields = ['sentence1', 'sentence2']
|
182 |
-
if not use_tfds:
|
183 |
-
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
|
184 |
-
|
185 |
-
vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
|
186 |
-
_create_fake_vocab_file(vocab_file_path)
|
187 |
-
|
188 |
-
data_config = loader.SentencePredictionTextDataConfig(
|
189 |
-
input_path='' if use_tfds else tf_record_path,
|
190 |
-
tfds_name='glue/mrpc' if use_tfds else '',
|
191 |
-
tfds_split='train' if use_tfds else '',
|
192 |
-
text_fields=text_fields,
|
193 |
-
global_batch_size=batch_size,
|
194 |
-
seq_length=seq_length,
|
195 |
-
is_training=True,
|
196 |
-
lower_case=lower_case,
|
197 |
-
vocab_file=vocab_file_path)
|
198 |
-
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
|
199 |
-
features = next(iter(dataset))
|
200 |
-
label_field = data_config.label_field
|
201 |
-
expected_keys = [
|
202 |
-
'input_word_ids', 'input_type_ids', 'input_mask', label_field
|
203 |
-
]
|
204 |
-
if use_tfds:
|
205 |
-
expected_keys += ['idx']
|
206 |
-
self.assertCountEqual(expected_keys, features.keys())
|
207 |
-
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
208 |
-
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
209 |
-
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
210 |
-
self.assertEqual(features[label_field].shape, (batch_size,))
|
211 |
-
|
212 |
-
@parameterized.parameters(True, False)
|
213 |
-
def test_python_sentencepiece_preprocessing(self, use_tfds):
|
214 |
-
batch_size = 10
|
215 |
-
seq_length = 256 # Non-default value.
|
216 |
-
lower_case = True
|
217 |
-
|
218 |
-
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
219 |
-
text_fields = ['sentence1', 'sentence2']
|
220 |
-
if not use_tfds:
|
221 |
-
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
|
222 |
-
|
223 |
-
sp_model_file_path = _create_fake_sentencepiece_model(self.get_temp_dir())
|
224 |
-
data_config = loader.SentencePredictionTextDataConfig(
|
225 |
-
input_path='' if use_tfds else tf_record_path,
|
226 |
-
tfds_name='glue/mrpc' if use_tfds else '',
|
227 |
-
tfds_split='train' if use_tfds else '',
|
228 |
-
text_fields=text_fields,
|
229 |
-
global_batch_size=batch_size,
|
230 |
-
seq_length=seq_length,
|
231 |
-
is_training=True,
|
232 |
-
lower_case=lower_case,
|
233 |
-
tokenization='SentencePiece',
|
234 |
-
vocab_file=sp_model_file_path,
|
235 |
-
)
|
236 |
-
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
|
237 |
-
features = next(iter(dataset))
|
238 |
-
label_field = data_config.label_field
|
239 |
-
expected_keys = [
|
240 |
-
'input_word_ids', 'input_type_ids', 'input_mask', label_field
|
241 |
-
]
|
242 |
-
if use_tfds:
|
243 |
-
expected_keys += ['idx']
|
244 |
-
self.assertCountEqual(expected_keys, features.keys())
|
245 |
-
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
246 |
-
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
247 |
-
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
248 |
-
self.assertEqual(features[label_field].shape, (batch_size,))
|
249 |
-
|
250 |
-
@parameterized.parameters(True, False)
|
251 |
-
def test_saved_model_preprocessing(self, use_tfds):
|
252 |
-
batch_size = 10
|
253 |
-
seq_length = 256 # Non-default value.
|
254 |
-
|
255 |
-
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
256 |
-
text_fields = ['sentence1', 'sentence2']
|
257 |
-
if not use_tfds:
|
258 |
-
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='float')
|
259 |
-
|
260 |
-
vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
|
261 |
-
_create_fake_vocab_file(vocab_file_path)
|
262 |
-
data_config = loader.SentencePredictionTextDataConfig(
|
263 |
-
input_path='' if use_tfds else tf_record_path,
|
264 |
-
tfds_name='glue/mrpc' if use_tfds else '',
|
265 |
-
tfds_split='train' if use_tfds else '',
|
266 |
-
text_fields=text_fields,
|
267 |
-
global_batch_size=batch_size,
|
268 |
-
seq_length=seq_length,
|
269 |
-
is_training=True,
|
270 |
-
preprocessing_hub_module_url=(
|
271 |
-
'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
|
272 |
-
label_type='int' if use_tfds else 'float',
|
273 |
-
)
|
274 |
-
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
|
275 |
-
features = next(iter(dataset))
|
276 |
-
label_field = data_config.label_field
|
277 |
-
expected_keys = [
|
278 |
-
'input_word_ids', 'input_type_ids', 'input_mask', label_field
|
279 |
-
]
|
280 |
-
if use_tfds:
|
281 |
-
expected_keys += ['idx']
|
282 |
-
self.assertCountEqual(expected_keys, features.keys())
|
283 |
-
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
284 |
-
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
285 |
-
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
286 |
-
self.assertEqual(features[label_field].shape, (batch_size,))
|
287 |
-
|
288 |
-
|
289 |
-
if __name__ == '__main__':
|
290 |
-
tf.test.main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|