Pradeep Kumar committed on
Commit
f18e71f
•
1 Parent(s): c130734

Upload 33 files

README.md CHANGED
@@ -1,13 +1,4 @@
1
- ---
2
- title: ISCO Code Predictor Api
3
- emoji: 📉
4
- colorFrom: green
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 4.41.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ This directory contains binaries and utilities required for input preprocessing,
2
+ tokenization, etc., that can be used with the model building blocks available in the
3
+ NLP modeling library [nlp/modeling](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)
4
+ to train custom models and validate new research ideas.
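As a rough usage sketch (the vocab path, data directory, and output paths are placeholders, and the CoLA processor is just one illustrative choice), these utilities plug into a fine-tuning data pipeline like this:

from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization

# Build a WordPiece tokenizer from a BERT vocabulary (placeholder path).
tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)

# Pick a task processor and write train/eval/test TFRecords plus meta data.
processor = classifier_data_lib.ColaProcessor(
    process_text_fn=tokenization.convert_to_unicode)
input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
    processor,
    "/path/to/glue/CoLA",  # placeholder input data dir
    tokenizer,
    train_data_output_path="/tmp/cola_train.tf_record",
    eval_data_output_path="/tmp/cola_eval.tf_record",
    test_data_output_path="/tmp/cola_test.tf_record",
    max_seq_length=128)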
__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
classifier_data_lib.py CHANGED
The diff for this file is too large to render. See raw diff
 
classifier_data_lib_test.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib."""
16
+
17
+ import os
18
+ import tempfile
19
+
20
+ from absl.testing import parameterized
21
+ import tensorflow as tf, tf_keras
22
+ import tensorflow_datasets as tfds
23
+
24
+ from official.nlp.data import classifier_data_lib
25
+ from official.nlp.tools import tokenization
26
+
27
+
28
+ def decode_record(record, name_to_features):
29
+ """Decodes a record to a TensorFlow example."""
30
+ return tf.io.parse_single_example(record, name_to_features)
31
+
32
+
33
+ class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
34
+
35
+ def setUp(self):
36
+ super(BertClassifierLibTest, self).setUp()
37
+ self.model_dir = self.get_temp_dir()
38
+ self.processors = {
39
+ "CB": classifier_data_lib.CBProcessor,
40
+ "SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
41
+ "BOOLQ": classifier_data_lib.BoolQProcessor,
42
+ "WIC": classifier_data_lib.WiCProcessor,
43
+ }
44
+
45
+ vocab_tokens = [
46
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
47
+ "##ing", ","
48
+ ]
49
+ with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
50
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens
51
+ ]).encode("utf-8"))
52
+ vocab_file = vocab_writer.name
53
+ self.tokenizer = tokenization.FullTokenizer(vocab_file)
54
+
55
+ @parameterized.parameters(
56
+ {"task_type": "CB"},
57
+ {"task_type": "BOOLQ"},
58
+ {"task_type": "SUPERGLUE-RTE"},
59
+ {"task_type": "WIC"},
60
+ )
61
+ def test_generate_dataset_from_tfds_processor(self, task_type):
62
+ with tfds.testing.mock_data(num_examples=5):
63
+ output_path = os.path.join(self.model_dir, task_type)
64
+
65
+ processor = self.processors[task_type]()
66
+
67
+ classifier_data_lib.generate_tf_record_from_data_file(
68
+ processor,
69
+ None,
70
+ self.tokenizer,
71
+ train_data_output_path=output_path,
72
+ eval_data_output_path=output_path,
73
+ test_data_output_path=output_path)
74
+ files = tf.io.gfile.glob(output_path)
75
+ self.assertNotEmpty(files)
76
+
77
+ train_dataset = tf.data.TFRecordDataset(output_path)
78
+ seq_length = 128
79
+ label_type = tf.int64
80
+ name_to_features = {
81
+ "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
82
+ "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
83
+ "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
84
+ "label_ids": tf.io.FixedLenFeature([], label_type),
85
+ }
86
+ train_dataset = train_dataset.map(
87
+ lambda record: decode_record(record, name_to_features))
88
+
89
+ # If data is retrieved without error, then all requirements
90
+ # including data type/shapes are met.
91
+ _ = next(iter(train_dataset))
92
+
93
+
94
+ if __name__ == "__main__":
95
+ tf.test.main()
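The parsing spec exercised by this test can also serve as the basis of a training input pipeline; a minimal sketch (the file pattern, sequence length, and batch size are arbitrary placeholders):

import tensorflow as tf

def make_classifier_dataset(file_pattern, seq_length=128, batch_size=32):
  """Builds a batched dataset from the generated classifier TFRecords."""
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
  }
  dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
  dataset = dataset.map(
      lambda record: tf.io.parse_single_example(record, name_to_features),
      num_parallel_calls=tf.data.AUTOTUNE)
  return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)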
create_finetuning_data.py ADDED
@@ -0,0 +1,441 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """BERT finetuning task dataset generator."""
16
+
17
+ import functools
18
+ import json
19
+ import os
20
+
21
+ # Import libraries
22
+ from absl import app
23
+ from absl import flags
24
+ import tensorflow as tf, tf_keras
25
+ from official.nlp.data import classifier_data_lib
26
+ from official.nlp.data import sentence_retrieval_lib
27
+ # word-piece tokenizer based squad_lib
28
+ from official.nlp.data import squad_lib as squad_lib_wp
29
+ # sentence-piece tokenizer based squad_lib
30
+ from official.nlp.data import squad_lib_sp
31
+ from official.nlp.data import tagging_data_lib
32
+ from official.nlp.tools import tokenization
33
+
34
+ FLAGS = flags.FLAGS
35
+
36
+ flags.DEFINE_enum(
37
+ "fine_tuning_task_type", "classification",
38
+ ["classification", "regression", "squad", "retrieval", "tagging"],
39
+ "The name of the BERT fine tuning task for which data "
40
+ "will be generated.")
41
+
42
+ # BERT classification specific flags.
43
+ flags.DEFINE_string(
44
+ "input_data_dir", None,
45
+ "The input data dir. Should contain the .tsv files (or other data files) "
46
+ "for the task.")
47
+
48
+ flags.DEFINE_enum(
49
+ "classification_task_name", "MNLI", [
50
+ "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
51
+ "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
52
+ "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
53
+ ], "The name of the task to train BERT classifier. The "
54
+ "difference between XTREME-XNLI and XNLI is: 1. the format "
55
+ "of input tsv files; 2. the dev set for XTREME is english "
56
+ "only and for XNLI is all languages combined. Same for "
57
+ "PAWS-X.")
58
+
59
+ # MNLI task-specific flag.
60
+ flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
61
+ "The type of MNLI dataset.")
62
+
63
+ # XNLI task-specific flag.
64
+ flags.DEFINE_string(
65
+ "xnli_language", "en",
66
+ "Language of training data for XNLI task. If the value is 'all', the data "
67
+ "of all languages will be used for training.")
68
+
69
+ # PAWS-X task-specific flag.
70
+ flags.DEFINE_string(
71
+ "pawsx_language", "en",
72
+ "Language of training data for PAWS-X task. If the value is 'all', the data "
73
+ "of all languages will be used for training.")
74
+
75
+ # XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
76
+ flags.DEFINE_string(
77
+ "translated_input_data_dir", None,
78
+ "The translated input data dir. Should contain the .tsv files (or other "
79
+ "data files) for the task.")
80
+
81
+ # Retrieval task-specific flags.
82
+ flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
83
+ "The name of sentence retrieval task for scoring")
84
+
85
+ # Tagging task-specific flags.
86
+ flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
87
+ "The name of BERT tagging (token classification) task.")
88
+
89
+ flags.DEFINE_bool("tagging_only_use_en_train", True,
90
+ "Whether only use english training data in tagging.")
91
+
92
+ # BERT Squad task-specific flags.
93
+ flags.DEFINE_string(
94
+ "squad_data_file", None,
95
+ "The input data file in for generating training data for BERT squad task.")
96
+
97
+ flags.DEFINE_string(
98
+ "translated_squad_data_folder", None,
99
+ "The translated data folder for generating training data for BERT squad "
100
+ "task.")
101
+
102
+ flags.DEFINE_integer(
103
+ "doc_stride", 128,
104
+ "When splitting up a long document into chunks, how much stride to "
105
+ "take between chunks.")
106
+
107
+ flags.DEFINE_integer(
108
+ "max_query_length", 64,
109
+ "The maximum number of tokens for the question. Questions longer than "
110
+ "this will be truncated to this length.")
111
+
112
+ flags.DEFINE_bool(
113
+ "version_2_with_negative", False,
114
+ "If true, the SQuAD examples contain some that do not have an answer.")
115
+
116
+ flags.DEFINE_bool(
117
+ "xlnet_format", False,
118
+ "If true, then data will be preprocessed in a paragraph, query, class order"
119
+ " instead of the BERT-style class, paragraph, query order.")
120
+
121
+ # XTREME specific flags.
122
+ flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")
123
+
124
+ # Shared flags across BERT fine-tuning tasks.
125
+ flags.DEFINE_string("vocab_file", None,
126
+ "The vocabulary file that the BERT model was trained on.")
127
+
128
+ flags.DEFINE_string(
129
+ "train_data_output_path", None,
130
+ "The path in which generated training input data will be written as tf"
131
+ " records.")
132
+
133
+ flags.DEFINE_string(
134
+ "eval_data_output_path", None,
135
+ "The path in which generated evaluation input data will be written as tf"
136
+ " records.")
137
+
138
+ flags.DEFINE_string(
139
+ "test_data_output_path", None,
140
+ "The path in which generated test input data will be written as tf"
141
+ " records. If None, do not generate test data. Must be a pattern template"
142
+ " as test_{}.tfrecords if processor has language specific test data.")
143
+
144
+ flags.DEFINE_string("meta_data_file_path", None,
145
+ "The path in which input meta data will be written.")
146
+
147
+ flags.DEFINE_bool(
148
+ "do_lower_case", True,
149
+ "Whether to lower case the input text. Should be True for uncased "
150
+ "models and False for cased models.")
151
+
152
+ flags.DEFINE_integer(
153
+ "max_seq_length", 128,
154
+ "The maximum total input sequence length after WordPiece tokenization. "
155
+ "Sequences longer than this will be truncated, and sequences shorter "
156
+ "than this will be padded.")
157
+
158
+ flags.DEFINE_string("sp_model_file", "",
159
+ "The path to the model used by sentence piece tokenizer.")
160
+
161
+ flags.DEFINE_enum(
162
+ "tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
163
+ "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
164
+ "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
165
+ "while ALBERT uses SentencePiece tokenizer.")
166
+
167
+ flags.DEFINE_string(
168
+ "tfds_params", "", "Comma-separated list of TFDS parameter assignments for "
169
+ "generic classfication data import (for more details "
170
+ "see the TfdsProcessor class documentation).")
171
+
172
+
173
+ def generate_classifier_dataset():
174
+ """Generates classifier dataset and returns input meta data."""
175
+ if FLAGS.classification_task_name in [
176
+ "COLA",
177
+ "WNLI",
178
+ "SST-2",
179
+ "MRPC",
180
+ "QQP",
181
+ "STS-B",
182
+ "MNLI",
183
+ "QNLI",
184
+ "RTE",
185
+ "AX",
186
+ "SUPERGLUE-RTE",
187
+ "CB",
188
+ "BoolQ",
189
+ "WIC",
190
+ ]:
191
+ assert not FLAGS.input_data_dir or FLAGS.tfds_params
192
+ else:
193
+ assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
194
+ FLAGS.tfds_params)
195
+
196
+ if FLAGS.tokenization == "WordPiece":
197
+ tokenizer = tokenization.FullTokenizer(
198
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
199
+ processor_text_fn = tokenization.convert_to_unicode
200
+ else:
201
+ assert FLAGS.tokenization == "SentencePiece"
202
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
203
+ processor_text_fn = functools.partial(
204
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
205
+
206
+ if FLAGS.tfds_params:
207
+ processor = classifier_data_lib.TfdsProcessor(
208
+ tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
209
+ return classifier_data_lib.generate_tf_record_from_data_file(
210
+ processor,
211
+ None,
212
+ tokenizer,
213
+ train_data_output_path=FLAGS.train_data_output_path,
214
+ eval_data_output_path=FLAGS.eval_data_output_path,
215
+ test_data_output_path=FLAGS.test_data_output_path,
216
+ max_seq_length=FLAGS.max_seq_length)
217
+ else:
218
+ processors = {
219
+ "ax":
220
+ classifier_data_lib.AxProcessor,
221
+ "cola":
222
+ classifier_data_lib.ColaProcessor,
223
+ "imdb":
224
+ classifier_data_lib.ImdbProcessor,
225
+ "mnli":
226
+ functools.partial(
227
+ classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
228
+ "mrpc":
229
+ classifier_data_lib.MrpcProcessor,
230
+ "qnli":
231
+ classifier_data_lib.QnliProcessor,
232
+ "qqp":
233
+ classifier_data_lib.QqpProcessor,
234
+ "rte":
235
+ classifier_data_lib.RteProcessor,
236
+ "sst-2":
237
+ classifier_data_lib.SstProcessor,
238
+ "sts-b":
239
+ classifier_data_lib.StsBProcessor,
240
+ "xnli":
241
+ functools.partial(
242
+ classifier_data_lib.XnliProcessor,
243
+ language=FLAGS.xnli_language),
244
+ "paws-x":
245
+ functools.partial(
246
+ classifier_data_lib.PawsxProcessor,
247
+ language=FLAGS.pawsx_language),
248
+ "wnli":
249
+ classifier_data_lib.WnliProcessor,
250
+ "xtreme-xnli":
251
+ functools.partial(
252
+ classifier_data_lib.XtremeXnliProcessor,
253
+ translated_data_dir=FLAGS.translated_input_data_dir,
254
+ only_use_en_dev=FLAGS.only_use_en_dev),
255
+ "xtreme-paws-x":
256
+ functools.partial(
257
+ classifier_data_lib.XtremePawsxProcessor,
258
+ translated_data_dir=FLAGS.translated_input_data_dir,
259
+ only_use_en_dev=FLAGS.only_use_en_dev),
260
+ "ax-g":
261
+ classifier_data_lib.AXgProcessor,
262
+ "superglue-rte":
263
+ classifier_data_lib.SuperGLUERTEProcessor,
264
+ "cb":
265
+ classifier_data_lib.CBProcessor,
266
+ "boolq":
267
+ classifier_data_lib.BoolQProcessor,
268
+ "wic":
269
+ classifier_data_lib.WiCProcessor,
270
+ }
271
+ task_name = FLAGS.classification_task_name.lower()
272
+ if task_name not in processors:
273
+ raise ValueError("Task not found: %s" % (task_name,))
274
+
275
+ processor = processors[task_name](process_text_fn=processor_text_fn)
276
+ return classifier_data_lib.generate_tf_record_from_data_file(
277
+ processor,
278
+ FLAGS.input_data_dir,
279
+ tokenizer,
280
+ train_data_output_path=FLAGS.train_data_output_path,
281
+ eval_data_output_path=FLAGS.eval_data_output_path,
282
+ test_data_output_path=FLAGS.test_data_output_path,
283
+ max_seq_length=FLAGS.max_seq_length)
284
+
285
+
286
+ def generate_regression_dataset():
287
+ """Generates regression dataset and returns input meta data."""
288
+ if FLAGS.tokenization == "WordPiece":
289
+ tokenizer = tokenization.FullTokenizer(
290
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
291
+ processor_text_fn = tokenization.convert_to_unicode
292
+ else:
293
+ assert FLAGS.tokenization == "SentencePiece"
294
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
295
+ processor_text_fn = functools.partial(
296
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
297
+
298
+ if FLAGS.tfds_params:
299
+ processor = classifier_data_lib.TfdsProcessor(
300
+ tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
301
+ return classifier_data_lib.generate_tf_record_from_data_file(
302
+ processor,
303
+ None,
304
+ tokenizer,
305
+ train_data_output_path=FLAGS.train_data_output_path,
306
+ eval_data_output_path=FLAGS.eval_data_output_path,
307
+ test_data_output_path=FLAGS.test_data_output_path,
308
+ max_seq_length=FLAGS.max_seq_length)
309
+ else:
310
+ raise ValueError("No data processor found for the given regression task.")
311
+
312
+
313
+ def generate_squad_dataset():
314
+ """Generates squad training dataset and returns input meta data."""
315
+ assert FLAGS.squad_data_file
316
+ if FLAGS.tokenization == "WordPiece":
317
+ return squad_lib_wp.generate_tf_record_from_json_file(
318
+ input_file_path=FLAGS.squad_data_file,
319
+ vocab_file_path=FLAGS.vocab_file,
320
+ output_path=FLAGS.train_data_output_path,
321
+ translated_input_folder=FLAGS.translated_squad_data_folder,
322
+ max_seq_length=FLAGS.max_seq_length,
323
+ do_lower_case=FLAGS.do_lower_case,
324
+ max_query_length=FLAGS.max_query_length,
325
+ doc_stride=FLAGS.doc_stride,
326
+ version_2_with_negative=FLAGS.version_2_with_negative,
327
+ xlnet_format=FLAGS.xlnet_format)
328
+ else:
329
+ assert FLAGS.tokenization == "SentencePiece"
330
+ return squad_lib_sp.generate_tf_record_from_json_file(
331
+ input_file_path=FLAGS.squad_data_file,
332
+ sp_model_file=FLAGS.sp_model_file,
333
+ output_path=FLAGS.train_data_output_path,
334
+ translated_input_folder=FLAGS.translated_squad_data_folder,
335
+ max_seq_length=FLAGS.max_seq_length,
336
+ do_lower_case=FLAGS.do_lower_case,
337
+ max_query_length=FLAGS.max_query_length,
338
+ doc_stride=FLAGS.doc_stride,
339
+ xlnet_format=FLAGS.xlnet_format,
340
+ version_2_with_negative=FLAGS.version_2_with_negative)
341
+
342
+
343
+ def generate_retrieval_dataset():
344
+ """Generate retrieval test and dev dataset and returns input meta data."""
345
+ assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
346
+ if FLAGS.tokenization == "WordPiece":
347
+ tokenizer = tokenization.FullTokenizer(
348
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
349
+ processor_text_fn = tokenization.convert_to_unicode
350
+ else:
351
+ assert FLAGS.tokenization == "SentencePiece"
352
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
353
+ processor_text_fn = functools.partial(
354
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
355
+
356
+ processors = {
357
+ "bucc": sentence_retrieval_lib.BuccProcessor,
358
+ "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
359
+ }
360
+
361
+ task_name = FLAGS.retrieval_task_name.lower()
362
+ if task_name not in processors:
363
+ raise ValueError("Task not found: %s" % task_name)
364
+
365
+ processor = processors[task_name](process_text_fn=processor_text_fn)
366
+
367
+ return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
368
+ processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
369
+ FLAGS.test_data_output_path, FLAGS.max_seq_length)
370
+
371
+
372
+ def generate_tagging_dataset():
373
+ """Generates tagging dataset."""
374
+ processors = {
375
+ "panx":
376
+ functools.partial(
377
+ tagging_data_lib.PanxProcessor,
378
+ only_use_en_train=FLAGS.tagging_only_use_en_train,
379
+ only_use_en_dev=FLAGS.only_use_en_dev),
380
+ "udpos":
381
+ functools.partial(
382
+ tagging_data_lib.UdposProcessor,
383
+ only_use_en_train=FLAGS.tagging_only_use_en_train,
384
+ only_use_en_dev=FLAGS.only_use_en_dev),
385
+ }
386
+ task_name = FLAGS.tagging_task_name.lower()
387
+ if task_name not in processors:
388
+ raise ValueError("Task not found: %s" % task_name)
389
+
390
+ if FLAGS.tokenization == "WordPiece":
391
+ tokenizer = tokenization.FullTokenizer(
392
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
393
+ processor_text_fn = tokenization.convert_to_unicode
394
+ elif FLAGS.tokenization == "SentencePiece":
395
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
396
+ processor_text_fn = functools.partial(
397
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
398
+ else:
399
+ raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)
400
+
401
+ processor = processors[task_name]()
402
+ return tagging_data_lib.generate_tf_record_from_data_file(
403
+ processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
404
+ FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
405
+ FLAGS.test_data_output_path, processor_text_fn)
406
+
407
+
408
+ def main(_):
409
+ if FLAGS.tokenization == "WordPiece":
410
+ if not FLAGS.vocab_file:
411
+ raise ValueError(
412
+ "FLAG vocab_file for word-piece tokenizer is not specified.")
413
+ else:
414
+ assert FLAGS.tokenization == "SentencePiece"
415
+ if not FLAGS.sp_model_file:
416
+ raise ValueError(
417
+ "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
418
+
419
+ if FLAGS.fine_tuning_task_type != "retrieval":
420
+ flags.mark_flag_as_required("train_data_output_path")
421
+
422
+ if FLAGS.fine_tuning_task_type == "classification":
423
+ input_meta_data = generate_classifier_dataset()
424
+ elif FLAGS.fine_tuning_task_type == "regression":
425
+ input_meta_data = generate_regression_dataset()
426
+ elif FLAGS.fine_tuning_task_type == "retrieval":
427
+ input_meta_data = generate_retrieval_dataset()
428
+ elif FLAGS.fine_tuning_task_type == "squad":
429
+ input_meta_data = generate_squad_dataset()
430
+ else:
431
+ assert FLAGS.fine_tuning_task_type == "tagging"
432
+ input_meta_data = generate_tagging_dataset()
433
+
434
+ tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
435
+ with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
436
+ writer.write(json.dumps(input_meta_data, indent=4) + "\n")
437
+
438
+
439
+ if __name__ == "__main__":
440
+ flags.mark_flag_as_required("meta_data_file_path")
441
+ app.run(main)
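For reference, a hedged example of invoking this script to generate MNLI classification data (all paths are placeholders; the flag names are the ones defined above):

# python create_finetuning_data.py \
#     --fine_tuning_task_type=classification \
#     --classification_task_name=MNLI \
#     --input_data_dir=/path/to/glue/MNLI \
#     --vocab_file=/path/to/vocab.txt \
#     --train_data_output_path=/tmp/mnli_train.tf_record \
#     --eval_data_output_path=/tmp/mnli_eval.tf_record \
#     --meta_data_file_path=/tmp/mnli_meta_data \
#     --max_seq_length=128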
create_pretraining_data.py ADDED
@@ -0,0 +1,718 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Create masked LM/next sentence masked_lm TF examples for BERT."""
16
+
17
+ import collections
18
+ import itertools
19
+ import random
20
+
21
+ # Import libraries
22
+
23
+ from absl import app
24
+ from absl import flags
25
+ from absl import logging
26
+ import tensorflow as tf, tf_keras
27
+
28
+ from official.nlp.tools import tokenization
29
+
30
+ FLAGS = flags.FLAGS
31
+
32
+ flags.DEFINE_string("input_file", None,
33
+ "Input raw text file (or comma-separated list of files).")
34
+
35
+ flags.DEFINE_string(
36
+ "output_file", None,
37
+ "Output TF example file (or comma-separated list of files).")
38
+
39
+ flags.DEFINE_enum(
40
+ "tokenization",
41
+ "WordPiece",
42
+ ["WordPiece", "SentencePiece"],
43
+ "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
44
+ "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
45
+ "while ALBERT uses SentencePiece tokenizer.",
46
+ )
47
+
48
+ flags.DEFINE_string(
49
+ "vocab_file",
50
+ None,
51
+ "For WordPiece tokenization, the vocabulary file of the tokenizer.",
52
+ )
53
+
54
+ flags.DEFINE_string(
55
+ "sp_model_file",
56
+ "",
57
+ "For SentencePiece tokenization, the path to the model of the tokenizer.",
58
+ )
59
+
60
+ flags.DEFINE_bool(
61
+ "do_lower_case", True,
62
+ "Whether to lower case the input text. Should be True for uncased "
63
+ "models and False for cased models.")
64
+
65
+ flags.DEFINE_bool(
66
+ "do_whole_word_mask",
67
+ False,
68
+ "Whether to use whole word masking rather than per-token masking.",
69
+ )
70
+
71
+ flags.DEFINE_integer(
72
+ "max_ngram_size", None,
73
+ "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
74
+ "weighting scheme to favor shorter n-grams. "
75
+ "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
76
+
77
+ flags.DEFINE_bool(
78
+ "gzip_compress", False,
79
+ "Whether to use `GZIP` compress option to get compressed TFRecord files.")
80
+
81
+ flags.DEFINE_bool(
82
+ "use_v2_feature_names", False,
83
+ "Whether to use the feature names consistent with the models.")
84
+
85
+ flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
86
+
87
+ flags.DEFINE_integer("max_predictions_per_seq", 20,
88
+ "Maximum number of masked LM predictions per sequence.")
89
+
90
+ flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
91
+
92
+ flags.DEFINE_integer(
93
+ "dupe_factor", 10,
94
+ "Number of times to duplicate the input data (with different masks).")
95
+
96
+ flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
97
+
98
+ flags.DEFINE_float(
99
+ "short_seq_prob", 0.1,
100
+ "Probability of creating sequences which are shorter than the "
101
+ "maximum length.")
102
+
103
+
104
+ class TrainingInstance(object):
105
+ """A single training instance (sentence pair)."""
106
+
107
+ def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
108
+ is_random_next):
109
+ self.tokens = tokens
110
+ self.segment_ids = segment_ids
111
+ self.is_random_next = is_random_next
112
+ self.masked_lm_positions = masked_lm_positions
113
+ self.masked_lm_labels = masked_lm_labels
114
+
115
+ def __str__(self):
116
+ s = ""
117
+ s += "tokens: %s\n" % (" ".join(
118
+ [tokenization.printable_text(x) for x in self.tokens]))
119
+ s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
120
+ s += "is_random_next: %s\n" % self.is_random_next
121
+ s += "masked_lm_positions: %s\n" % (" ".join(
122
+ [str(x) for x in self.masked_lm_positions]))
123
+ s += "masked_lm_labels: %s\n" % (" ".join(
124
+ [tokenization.printable_text(x) for x in self.masked_lm_labels]))
125
+ s += "\n"
126
+ return s
127
+
128
+ def __repr__(self):
129
+ return self.__str__()
130
+
131
+
132
+ def write_instance_to_example_files(instances, tokenizer, max_seq_length,
133
+ max_predictions_per_seq, output_files,
134
+ gzip_compress, use_v2_feature_names):
135
+ """Creates TF example files from `TrainingInstance`s."""
136
+ writers = []
137
+ for output_file in output_files:
138
+ writers.append(
139
+ tf.io.TFRecordWriter(
140
+ output_file, options="GZIP" if gzip_compress else ""))
141
+
142
+ writer_index = 0
143
+
144
+ total_written = 0
145
+ for (inst_index, instance) in enumerate(instances):
146
+ input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
147
+ input_mask = [1] * len(input_ids)
148
+ segment_ids = list(instance.segment_ids)
149
+ assert len(input_ids) <= max_seq_length
150
+
151
+ while len(input_ids) < max_seq_length:
152
+ input_ids.append(0)
153
+ input_mask.append(0)
154
+ segment_ids.append(0)
155
+
156
+ assert len(input_ids) == max_seq_length
157
+ assert len(input_mask) == max_seq_length
158
+ assert len(segment_ids) == max_seq_length
159
+
160
+ masked_lm_positions = list(instance.masked_lm_positions)
161
+ masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
162
+ masked_lm_weights = [1.0] * len(masked_lm_ids)
163
+
164
+ while len(masked_lm_positions) < max_predictions_per_seq:
165
+ masked_lm_positions.append(0)
166
+ masked_lm_ids.append(0)
167
+ masked_lm_weights.append(0.0)
168
+
169
+ next_sentence_label = 1 if instance.is_random_next else 0
170
+
171
+ features = collections.OrderedDict()
172
+ if use_v2_feature_names:
173
+ features["input_word_ids"] = create_int_feature(input_ids)
174
+ features["input_type_ids"] = create_int_feature(segment_ids)
175
+ else:
176
+ features["input_ids"] = create_int_feature(input_ids)
177
+ features["segment_ids"] = create_int_feature(segment_ids)
178
+
179
+ features["input_mask"] = create_int_feature(input_mask)
180
+ features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
181
+ features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
182
+ features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
183
+ features["next_sentence_labels"] = create_int_feature([next_sentence_label])
184
+
185
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
186
+
187
+ writers[writer_index].write(tf_example.SerializeToString())
188
+ writer_index = (writer_index + 1) % len(writers)
189
+
190
+ total_written += 1
191
+
192
+ if inst_index < 20:
193
+ logging.info("*** Example ***")
194
+ logging.info("tokens: %s", " ".join(
195
+ [tokenization.printable_text(x) for x in instance.tokens]))
196
+
197
+ for feature_name in features.keys():
198
+ feature = features[feature_name]
199
+ values = []
200
+ if feature.int64_list.value:
201
+ values = feature.int64_list.value
202
+ elif feature.float_list.value:
203
+ values = feature.float_list.value
204
+ logging.info("%s: %s", feature_name, " ".join([str(x) for x in values]))
205
+
206
+ for writer in writers:
207
+ writer.close()
208
+
209
+ logging.info("Wrote %d total instances", total_written)
210
+
211
+
212
+ def create_int_feature(values):
213
+ feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
214
+ return feature
215
+
216
+
217
+ def create_float_feature(values):
218
+ feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
219
+ return feature
220
+
221
+
222
+ def create_training_instances(
223
+ input_files,
224
+ tokenizer,
225
+ processor_text_fn,
226
+ max_seq_length,
227
+ dupe_factor,
228
+ short_seq_prob,
229
+ masked_lm_prob,
230
+ max_predictions_per_seq,
231
+ rng,
232
+ do_whole_word_mask=False,
233
+ max_ngram_size=None,
234
+ ):
235
+ """Create `TrainingInstance`s from raw text."""
236
+ all_documents = [[]]
237
+
238
+ # Input file format:
239
+ # (1) One sentence per line. These should ideally be actual sentences, not
240
+ # entire paragraphs or arbitrary spans of text. (Because we use the
241
+ # sentence boundaries for the "next sentence prediction" task).
242
+ # (2) Blank lines between documents. Document boundaries are needed so
243
+ # that the "next sentence prediction" task doesn't span between documents.
244
+ for input_file in input_files:
245
+ with tf.io.gfile.GFile(input_file, "rb") as reader:
246
+ for line in reader:
247
+ line = processor_text_fn(line)
248
+
249
+ # Empty lines are used as document delimiters
250
+ if not line:
251
+ all_documents.append([])
252
+ tokens = tokenizer.tokenize(line)
253
+ if tokens:
254
+ all_documents[-1].append(tokens)
255
+
256
+ # Remove empty documents
257
+ all_documents = [x for x in all_documents if x]
258
+ rng.shuffle(all_documents)
259
+
260
+ vocab_words = list(tokenizer.vocab.keys())
261
+ instances = []
262
+ for _ in range(dupe_factor):
263
+ for document_index in range(len(all_documents)):
264
+ instances.extend(
265
+ create_instances_from_document(
266
+ all_documents, document_index, max_seq_length, short_seq_prob,
267
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
268
+ do_whole_word_mask, max_ngram_size))
269
+
270
+ rng.shuffle(instances)
271
+ return instances
272
+
273
+
274
+ def create_instances_from_document(
275
+ all_documents, document_index, max_seq_length, short_seq_prob,
276
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
277
+ do_whole_word_mask=False,
278
+ max_ngram_size=None):
279
+ """Creates `TrainingInstance`s for a single document."""
280
+ document = all_documents[document_index]
281
+
282
+ # Account for [CLS], [SEP], [SEP]
283
+ max_num_tokens = max_seq_length - 3
284
+
285
+ # We *usually* want to fill up the entire sequence since we are padding
286
+ # to `max_seq_length` anyways, so short sequences are generally wasted
287
+ # computation. However, we *sometimes*
288
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
289
+ # sequences to minimize the mismatch between pre-training and fine-tuning.
290
+ # The `target_seq_length` is just a rough target however, whereas
291
+ # `max_seq_length` is a hard limit.
292
+ target_seq_length = max_num_tokens
293
+ if rng.random() < short_seq_prob:
294
+ target_seq_length = rng.randint(2, max_num_tokens)
295
+
296
+ # We DON'T just concatenate all of the tokens from a document into a long
297
+ # sequence and choose an arbitrary split point because this would make the
298
+ # next sentence prediction task too easy. Instead, we split the input into
299
+ # segments "A" and "B" based on the actual "sentences" provided by the user
300
+ # input.
301
+ instances = []
302
+ current_chunk = []
303
+ current_length = 0
304
+ i = 0
305
+ while i < len(document):
306
+ segment = document[i]
307
+ current_chunk.append(segment)
308
+ current_length += len(segment)
309
+ if i == len(document) - 1 or current_length >= target_seq_length:
310
+ if current_chunk:
311
+ # `a_end` is how many segments from `current_chunk` go into the `A`
312
+ # (first) sentence.
313
+ a_end = 1
314
+ if len(current_chunk) >= 2:
315
+ a_end = rng.randint(1, len(current_chunk) - 1)
316
+
317
+ tokens_a = []
318
+ for j in range(a_end):
319
+ tokens_a.extend(current_chunk[j])
320
+
321
+ tokens_b = []
322
+ # Random next
323
+ is_random_next = False
324
+ if len(current_chunk) == 1 or rng.random() < 0.5:
325
+ is_random_next = True
326
+ target_b_length = target_seq_length - len(tokens_a)
327
+
328
+ # This should rarely go for more than one iteration for large
329
+ # corpora. However, just to be careful, we try to make sure that
330
+ # the random document is not the same as the document
331
+ # we're processing.
332
+ for _ in range(10):
333
+ random_document_index = rng.randint(0, len(all_documents) - 1)
334
+ if random_document_index != document_index:
335
+ break
336
+
337
+ random_document = all_documents[random_document_index]
338
+ random_start = rng.randint(0, len(random_document) - 1)
339
+ for j in range(random_start, len(random_document)):
340
+ tokens_b.extend(random_document[j])
341
+ if len(tokens_b) >= target_b_length:
342
+ break
343
+ # We didn't actually use these segments so we "put them back" so
344
+ # they don't go to waste.
345
+ num_unused_segments = len(current_chunk) - a_end
346
+ i -= num_unused_segments
347
+ # Actual next
348
+ else:
349
+ is_random_next = False
350
+ for j in range(a_end, len(current_chunk)):
351
+ tokens_b.extend(current_chunk[j])
352
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
353
+
354
+ assert len(tokens_a) >= 1
355
+ assert len(tokens_b) >= 1
356
+
357
+ tokens = []
358
+ segment_ids = []
359
+ tokens.append("[CLS]")
360
+ segment_ids.append(0)
361
+ for token in tokens_a:
362
+ tokens.append(token)
363
+ segment_ids.append(0)
364
+
365
+ tokens.append("[SEP]")
366
+ segment_ids.append(0)
367
+
368
+ for token in tokens_b:
369
+ tokens.append(token)
370
+ segment_ids.append(1)
371
+ tokens.append("[SEP]")
372
+ segment_ids.append(1)
373
+
374
+ (tokens, masked_lm_positions,
375
+ masked_lm_labels) = create_masked_lm_predictions(
376
+ tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
377
+ do_whole_word_mask, max_ngram_size)
378
+ instance = TrainingInstance(
379
+ tokens=tokens,
380
+ segment_ids=segment_ids,
381
+ is_random_next=is_random_next,
382
+ masked_lm_positions=masked_lm_positions,
383
+ masked_lm_labels=masked_lm_labels)
384
+ instances.append(instance)
385
+ current_chunk = []
386
+ current_length = 0
387
+ i += 1
388
+
389
+ return instances
390
+
391
+
392
+ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
393
+ ["index", "label"])
394
+
395
+ # A _Gram is a [half-open) interval of token indices which form a word.
396
+ # E.g.,
397
+ # words: ["The", "doghouse"]
398
+ # tokens: ["The", "dog", "##house"]
399
+ # grams: [(0,1), (1,3)]
400
+ _Gram = collections.namedtuple("_Gram", ["begin", "end"])
401
+
402
+
403
+ def _window(iterable, size):
404
+ """Helper to create a sliding window iterator with a given size.
405
+
406
+ E.g.,
407
+ input = [1, 2, 3, 4]
408
+ _window(input, 1) => [1], [2], [3], [4]
409
+ _window(input, 2) => [1, 2], [2, 3], [3, 4]
410
+ _window(input, 3) => [1, 2, 3], [2, 3, 4]
411
+ _window(input, 4) => [1, 2, 3, 4]
412
+ _window(input, 5) => None
413
+
414
+ Args:
415
+ iterable: elements to iterate over.
416
+ size: size of the window.
417
+
418
+ Yields:
419
+ Elements of `iterable` batched into a sliding window of length `size`.
420
+ """
421
+ i = iter(iterable)
422
+ window = []
423
+ try:
424
+ for e in range(0, size):
425
+ window.append(next(i))
426
+ yield window
427
+ except StopIteration:
428
+ # handle the case where iterable's length is less than the window size.
429
+ return
430
+ for e in i:
431
+ window = window[1:] + [e]
432
+ yield window
433
+
434
+
435
+ def _contiguous(sorted_grams):
436
+ """Test whether a sequence of grams is contiguous.
437
+
438
+ Args:
439
+ sorted_grams: _Grams which are sorted in increasing order.
440
+ Returns:
441
+ True if `sorted_grams` are touching each other.
442
+
443
+ E.g.,
444
+ _contiguous([(1, 4), (4, 5), (5, 10)]) == True
445
+ _contiguous([(1, 2), (4, 5)]) == False
446
+ """
447
+ for a, b in _window(sorted_grams, 2):
448
+ if a.end != b.begin:
449
+ return False
450
+ return True
451
+
452
+
453
+ def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
454
+ """Create a list of masking {1, ..., n}-grams from a list of one-grams.
455
+
456
+ This is an extension of 'whole word masking' to mask multiple, contiguous
457
+ words (e.g., "the red boat").
458
+
459
+ Each input gram represents the token indices of a single word,
460
+ words: ["the", "red", "boat"]
461
+ tokens: ["the", "red", "boa", "##t"]
462
+ grams: [(0,1), (1,2), (2,4)]
463
+
464
+ For a `max_ngram_size` of three, possible outputs masks include:
465
+ 1-grams: (0,1), (1,2), (2,4)
466
+ 2-grams: (0,2), (1,4)
467
+ 3-grams; (0,4)
468
+
469
+ Output masks will not overlap and will contain at most `max_masked_tokens` total
470
+ tokens. E.g., for the example above with `max_masked_tokens` as three,
471
+ valid outputs are,
472
+ [(0,1), (1,2)] # "the", "red" covering two tokens
473
+ [(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
474
+
475
+ The length of the selected n-gram follows a zipf weighting to
476
+ favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
477
+
478
+ Args:
479
+ grams: List of one-grams.
480
+ max_ngram_size: Maximum number of contiguous one-grams combined to create
481
+ an n-gram.
482
+ max_masked_tokens: Maximum total number of tokens to be masked.
483
+ rng: `random.Random` generator.
484
+
485
+ Returns:
486
+ A list of n-grams to be used as masks.
487
+ """
488
+ if not grams:
489
+ return None
490
+
491
+ grams = sorted(grams)
492
+ num_tokens = grams[-1].end
493
+
494
+ # Ensure our grams are valid (i.e., they don't overlap).
495
+ for a, b in _window(grams, 2):
496
+ if a.end > b.begin:
497
+ raise ValueError("overlapping grams: {}".format(grams))
498
+
499
+ # Build map from n-gram length to list of n-grams.
500
+ ngrams = {i: [] for i in range(1, max_ngram_size+1)}
501
+ for gram_size in range(1, max_ngram_size+1):
502
+ for g in _window(grams, gram_size):
503
+ if _contiguous(g):
504
+ # Add an n-gram which spans these one-grams.
505
+ ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
506
+
507
+ # Shuffle each list of n-grams.
508
+ for v in ngrams.values():
509
+ rng.shuffle(v)
510
+
511
+ # Create the weighting for n-gram length selection.
512
+ # Stored cumulatively for `random.choices` below.
513
+ cummulative_weights = list(
514
+ itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
515
+
516
+ output_ngrams = []
517
+ # Keep a bitmask of which tokens have been masked.
518
+ masked_tokens = [False] * num_tokens
519
+ # Loop until we have enough masked tokens or there are no more candidate
520
+ # n-grams of any length.
521
+ # Each code path should ensure one or more elements from `ngrams` are removed
522
+ # to guarantee this loop terminates.
523
+ while (sum(masked_tokens) < max_masked_tokens and
524
+ sum(len(s) for s in ngrams.values())):
525
+ # Pick an n-gram size based on our weights.
526
+ sz = random.choices(range(1, max_ngram_size+1),
527
+ cum_weights=cummulative_weights)[0]
528
+
529
+ # Ensure this size doesn't result in too many masked tokens.
530
+ # E.g., a two-gram contains _at least_ two tokens.
531
+ if sum(masked_tokens) + sz > max_masked_tokens:
532
+ # All n-grams of this length are too long and can be removed from
533
+ # consideration.
534
+ ngrams[sz].clear()
535
+ continue
536
+
537
+ # All of the n-grams of this size have been used.
538
+ if not ngrams[sz]:
539
+ continue
540
+
541
+ # Choose a random n-gram of the given size.
542
+ gram = ngrams[sz].pop()
543
+ num_gram_tokens = gram.end-gram.begin
544
+
545
+ # Check if this would add too many tokens.
546
+ if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
547
+ continue
548
+
549
+ # Check if any of the tokens in this gram have already been masked.
550
+ if sum(masked_tokens[gram.begin:gram.end]):
551
+ continue
552
+
553
+ # Found a usable n-gram! Mark its tokens as masked and add it to return.
554
+ masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
555
+ output_ngrams.append(gram)
556
+ return output_ngrams
557
+
558
+
559
+ def _tokens_to_grams(tokens):
560
+ """Reconstitue grams (words) from `tokens`.
561
+
562
+ E.g.,
563
+ tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
564
+ grams: [ [1,2), [2, 4), [4,5) , [5, 6)]
565
+
566
+ Args:
567
+ tokens: list of tokens (word pieces or sentence pieces).
568
+
569
+ Returns:
570
+ List of _Grams representing spans of whole words
571
+ (without "[CLS]" and "[SEP]").
572
+ """
573
+ grams = []
574
+ gram_start_pos = None
575
+ for i, token in enumerate(tokens):
576
+ if gram_start_pos is not None and token.startswith("##"):
577
+ continue
578
+ if gram_start_pos is not None:
579
+ grams.append(_Gram(gram_start_pos, i))
580
+ if token not in ["[CLS]", "[SEP]"]:
581
+ gram_start_pos = i
582
+ else:
583
+ gram_start_pos = None
584
+ if gram_start_pos is not None:
585
+ grams.append(_Gram(gram_start_pos, len(tokens)))
586
+ return grams
587
+
588
+
589
+ def create_masked_lm_predictions(tokens, masked_lm_prob,
590
+ max_predictions_per_seq, vocab_words, rng,
591
+ do_whole_word_mask,
592
+ max_ngram_size=None):
593
+ """Creates the predictions for the masked LM objective."""
594
+ if do_whole_word_mask:
595
+ grams = _tokens_to_grams(tokens)
596
+ else:
597
+ # Here we consider each token to be a word to allow for sub-word masking.
598
+ if max_ngram_size:
599
+ raise ValueError("cannot use ngram masking without whole word masking")
600
+ grams = [_Gram(i, i+1) for i in range(0, len(tokens))
601
+ if tokens[i] not in ["[CLS]", "[SEP]"]]
602
+
603
+ num_to_predict = min(max_predictions_per_seq,
604
+ max(1, int(round(len(tokens) * masked_lm_prob))))
605
+ # Generate masks. If `max_ngram_size` in [0, None] it means we're doing
606
+ # whole word masking or token level masking. Both of these can be treated
607
+ # as the `max_ngram_size=1` case.
608
+ masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
609
+ num_to_predict, rng)
610
+ masked_lms = []
611
+ output_tokens = list(tokens)
612
+ for gram in masked_grams:
613
+ # 80% of the time, replace all n-gram tokens with [MASK]
614
+ if rng.random() < 0.8:
615
+ replacement_action = lambda idx: "[MASK]"
616
+ else:
617
+ # 10% of the time, keep all the original n-gram tokens.
618
+ if rng.random() < 0.5:
619
+ replacement_action = lambda idx: tokens[idx]
620
+ # 10% of the time, replace each n-gram token with a random word.
621
+ else:
622
+ replacement_action = lambda idx: rng.choice(vocab_words)
623
+
624
+ for idx in range(gram.begin, gram.end):
625
+ output_tokens[idx] = replacement_action(idx)
626
+ masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
627
+
628
+ assert len(masked_lms) <= num_to_predict
629
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
630
+
631
+ masked_lm_positions = []
632
+ masked_lm_labels = []
633
+ for p in masked_lms:
634
+ masked_lm_positions.append(p.index)
635
+ masked_lm_labels.append(p.label)
636
+
637
+ return (output_tokens, masked_lm_positions, masked_lm_labels)
638
+
639
+
640
+ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
641
+ """Truncates a pair of sequences to a maximum sequence length."""
642
+ while True:
643
+ total_length = len(tokens_a) + len(tokens_b)
644
+ if total_length <= max_num_tokens:
645
+ break
646
+
647
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
648
+ assert len(trunc_tokens) >= 1
649
+
650
+ # We want to sometimes truncate from the front and sometimes from the
651
+ # back to add more randomness and avoid biases.
652
+ if rng.random() < 0.5:
653
+ del trunc_tokens[0]
654
+ else:
655
+ trunc_tokens.pop()
656
+
657
+
658
+ def get_processor_text_fn(is_sentence_piece, do_lower_case):
659
+ def processor_text_fn(text):
660
+ text = tokenization.convert_to_unicode(text)
661
+ if is_sentence_piece:
662
+ # Additional preprocessing specific to the SentencePiece tokenizer.
663
+ text = tokenization.preprocess_text(text, lower=do_lower_case)
664
+
665
+ return text.strip()
666
+
667
+ return processor_text_fn
668
+
669
+
670
+ def main(_):
671
+ if FLAGS.tokenization == "WordPiece":
672
+ tokenizer = tokenization.FullTokenizer(
673
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
674
+ )
675
+ processor_text_fn = get_processor_text_fn(False, FLAGS.do_lower_case)
676
+ else:
677
+ assert FLAGS.tokenization == "SentencePiece"
678
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
679
+ processor_text_fn = get_processor_text_fn(True, FLAGS.do_lower_case)
680
+
681
+ input_files = []
682
+ for input_pattern in FLAGS.input_file.split(","):
683
+ input_files.extend(tf.io.gfile.glob(input_pattern))
684
+
685
+ logging.info("*** Reading from input files ***")
686
+ for input_file in input_files:
687
+ logging.info(" %s", input_file)
688
+
689
+ rng = random.Random(FLAGS.random_seed)
690
+ instances = create_training_instances(
691
+ input_files,
692
+ tokenizer,
693
+ processor_text_fn,
694
+ FLAGS.max_seq_length,
695
+ FLAGS.dupe_factor,
696
+ FLAGS.short_seq_prob,
697
+ FLAGS.masked_lm_prob,
698
+ FLAGS.max_predictions_per_seq,
699
+ rng,
700
+ FLAGS.do_whole_word_mask,
701
+ FLAGS.max_ngram_size,
702
+ )
703
+
704
+ output_files = FLAGS.output_file.split(",")
705
+ logging.info("*** Writing to output files ***")
706
+ for output_file in output_files:
707
+ logging.info(" %s", output_file)
708
+
709
+ write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
710
+ FLAGS.max_predictions_per_seq, output_files,
711
+ FLAGS.gzip_compress,
712
+ FLAGS.use_v2_feature_names)
713
+
714
+
715
+ if __name__ == "__main__":
716
+ flags.mark_flag_as_required("input_file")
717
+ flags.mark_flag_as_required("output_file")
718
+ app.run(main)
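A sketch of a typical invocation (paths are placeholders; whole-word and n-gram masking are enabled here purely for illustration, the remaining values restate the flag defaults above):

# python create_pretraining_data.py \
#     --input_file=/path/to/corpus-*.txt \
#     --output_file=/tmp/pretrain.tfrecord \
#     --vocab_file=/path/to/vocab.txt \
#     --do_lower_case=True \
#     --max_seq_length=128 \
#     --max_predictions_per_seq=20 \
#     --masked_lm_prob=0.15 \
#     --dupe_factor=10 \
#     --do_whole_word_mask=True \
#     --max_ngram_size=3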
create_pretraining_data_test.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.create_pretraining_data."""
16
+ import random
17
+
18
+ import tensorflow as tf, tf_keras
19
+
20
+ from official.nlp.data import create_pretraining_data as cpd
21
+
22
+ _VOCAB_WORDS = ["vocab_1", "vocab_2"]
23
+
24
+
25
+ class CreatePretrainingDataTest(tf.test.TestCase):
26
+
27
+ def assertTokens(self, input_tokens, output_tokens, masked_positions,
28
+ masked_labels):
29
+ # Ensure the masked positions are unique.
30
+ self.assertCountEqual(masked_positions, set(masked_positions))
31
+
32
+ # Ensure we can reconstruct the input from the output.
33
+ reconstructed_tokens = output_tokens
34
+ for pos, label in zip(masked_positions, masked_labels):
35
+ reconstructed_tokens[pos] = label
36
+ self.assertEqual(input_tokens, reconstructed_tokens)
37
+
38
+ # Ensure each label is valid.
39
+ for pos, label in zip(masked_positions, masked_labels):
40
+ output_token = output_tokens[pos]
41
+ if (output_token == "[MASK]" or output_token in _VOCAB_WORDS or
42
+ output_token == input_tokens[pos]):
43
+ continue
44
+ self.fail("invalid mask value: {}".format(output_token))
45
+
46
+ def test_tokens_to_grams(self):
47
+ tests = [
48
+ (["That", "cone"], [(0, 1), (1, 2)]),
49
+ (["That", "cone", "##s"], [(0, 1), (1, 3)]),
50
+ (["Swit", "##zer", "##land"], [(0, 3)]),
51
+ (["[CLS]", "Up", "##dog"], [(1, 3)]),
52
+ (["[CLS]", "Up", "##dog", "[SEP]", "Down"], [(1, 3), (4, 5)]),
53
+ ]
54
+ for inp, expected in tests:
55
+ output = cpd._tokens_to_grams(inp)
56
+ self.assertEqual(expected, output)
57
+
58
+ def test_window(self):
59
+ input_list = [1, 2, 3, 4]
60
+ window_outputs = [
61
+ (1, [[1], [2], [3], [4]]),
62
+ (2, [[1, 2], [2, 3], [3, 4]]),
63
+ (3, [[1, 2, 3], [2, 3, 4]]),
64
+ (4, [[1, 2, 3, 4]]),
65
+ (5, []),
66
+ ]
67
+ for window, expected in window_outputs:
68
+ output = cpd._window(input_list, window)
69
+ self.assertEqual(expected, list(output))
70
+
71
+ def test_create_masked_lm_predictions(self):
72
+ tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
73
+ rng = random.Random(123)
74
+ for _ in range(0, 5):
75
+ output_tokens, masked_positions, masked_labels = (
76
+ cpd.create_masked_lm_predictions(
77
+ tokens=tokens,
78
+ masked_lm_prob=1.0,
79
+ max_predictions_per_seq=3,
80
+ vocab_words=_VOCAB_WORDS,
81
+ rng=rng,
82
+ do_whole_word_mask=False,
83
+ max_ngram_size=None))
84
+ self.assertLen(masked_positions, 3)
85
+ self.assertLen(masked_labels, 3)
86
+ self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
87
+
88
+ def test_create_masked_lm_predictions_whole_word(self):
89
+ tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
90
+ rng = random.Random(345)
91
+ for _ in range(0, 5):
92
+ output_tokens, masked_positions, masked_labels = (
93
+ cpd.create_masked_lm_predictions(
94
+ tokens=tokens,
95
+ masked_lm_prob=1.0,
96
+ max_predictions_per_seq=3,
97
+ vocab_words=_VOCAB_WORDS,
98
+ rng=rng,
99
+ do_whole_word_mask=True,
100
+ max_ngram_size=None))
101
+ # since we can't get exactly three tokens without breaking a word we
102
+ # only take two.
103
+ self.assertLen(masked_positions, 2)
104
+ self.assertLen(masked_labels, 2)
105
+ self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
106
+ # ensure that we took an entire word.
107
+ self.assertIn(masked_labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])
108
+
109
+ def test_create_masked_lm_predictions_ngram(self):
110
+ tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(0, 512)] + ["[SEP]"]
111
+ rng = random.Random(345)
112
+ for _ in range(0, 5):
113
+ output_tokens, masked_positions, masked_labels = (
114
+ cpd.create_masked_lm_predictions(
115
+ tokens=tokens,
116
+ masked_lm_prob=1.0,
117
+ max_predictions_per_seq=76,
118
+ vocab_words=_VOCAB_WORDS,
119
+ rng=rng,
120
+ do_whole_word_mask=True,
121
+ max_ngram_size=3))
122
+ self.assertLen(masked_positions, 76)
123
+ self.assertLen(masked_labels, 76)
124
+ self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
125
+
126
+
127
+ if __name__ == "__main__":
128
+ tf.test.main()
create_xlnet_pretraining_data.py ADDED
@@ -0,0 +1,721 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Create LM TF examples for XLNet."""
16
+
17
+ import dataclasses
18
+ import json
19
+ import math
20
+ import os
21
+
22
+ import random
23
+ from typing import Iterable, Mapping, List, Optional, Tuple
24
+ import unicodedata
25
+
26
+ # Import libraries
27
+
28
+ from absl import app
29
+ from absl import flags
30
+ from absl import logging
31
+
32
+ import numpy as np
33
+ import tensorflow as tf, tf_keras
34
+
35
+ from official.nlp.tools import tokenization
36
+
37
+ special_symbols = {
38
+ "<unk>": 0,
39
+ "<s>": 1,
40
+ "</s>": 2,
41
+ "<cls>": 3,
42
+ "<sep>": 4,
43
+ "<pad>": 5,
44
+ "<mask>": 6,
45
+ "<eod>": 7,
46
+ "<eop>": 8,
47
+ }
48
+
49
+ FLAGS = flags.FLAGS
50
+
51
+ flags.DEFINE_integer("seq_length", 512,
52
+ help="Sequence length.")
53
+ flags.DEFINE_integer("reuse_length", 256,
54
+ help="Number of token that can be reused as memory. "
55
+ "Could be half of `seq_len`.")
56
+ flags.DEFINE_string("input_file", None,
57
+ "Input raw text file (or comma-separated list of files).")
58
+ flags.DEFINE_string(
59
+ "save_dir", None,
60
+ "Directory for saving processed data.")
61
+ flags.DEFINE_string("sp_model_file", "",
62
+ "The path to the model used by sentence piece tokenizer.")
63
+ flags.DEFINE_bool("use_eod_token", True,
64
+ "Whether or not to include EOD tokens.")
65
+ flags.DEFINE_bool("bi_data", True, "Whether or not to use bi-directional data.")
66
+ flags.DEFINE_bool(
67
+ "do_lower_case", True,
68
+ "Whether to lower case the input text. Should be True for uncased "
69
+ "models and False for cased models.")
70
+ flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
71
+ flags.DEFINE_integer("num_cores_per_host", 16,
72
+ "The number of (TPU) cores per host.")
73
+ flags.DEFINE_string("prefix", "", "Filename prefix.")
74
+ flags.DEFINE_string("suffix", "", "Filename suffix.")
75
+
76
+ flags.DEFINE_integer("task_id", None,
77
+ "The id of the current task.")
78
+ flags.DEFINE_integer("num_tasks", None,
79
+ "The total number of tasks.")
80
+ flags.DEFINE_integer("num_passes", 1, "The number of times to run the script.")
81
+
82
+
83
+ @dataclasses.dataclass
84
+ class TrainingInstance:
85
+ """Representation of a single XLNet Pretraining instance."""
86
+ data: Iterable[int]
87
+ segment_ids: Iterable[int]
88
+ boundary_indices: Iterable[int]
89
+ label: int
90
+
91
+ def to_feature(self) -> Mapping[str, tf.train.Feature]:
92
+ feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
93
+ return dict(
94
+ input_word_ids=feat(self.data),
95
+ input_type_ids=feat(self.segment_ids),
96
+ boundary_indices=feat(self.boundary_indices),
97
+ label=feat([self.label]))
98
+
99
+ def to_example(self) -> tf.train.Example:
100
+ return tf.train.Example(
101
+ features=tf.train.Features(feature=self.to_feature()))
102
+
103
+ def __str__(self):
104
+ def seq_to_str(seq):
105
+ return " ".join([str(x) for x in seq])
106
+
107
+ s = ""
108
+ s += "tokens: %s\n" % seq_to_str(self.data)
109
+ s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
110
+ s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
111
+ s += "label: %s\n" % self.label
112
+ s += "\n"
113
+ return s
114
+
115
+ def __repr__(self):
116
+ return self.__str__()
117
+
118
+
119
+ def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
120
+ """Preprocesses an individual raw text line.
121
+
122
+ This function will:
123
+ - Removes extraneous spaces.
124
+ - Replaces `` and '' with ".
125
+ - Removes accents.
126
+ - Applies lower casing.
127
+
128
+ Args:
129
+ line: The input line to preprocess.
130
+ do_lower_case: Whether or not to lower case the text.
131
+
132
+ Returns:
133
+ The preprocessed line.
134
+
135
+ """
136
+ line = " ".join(line.split())
137
+ line = line.replace("``", "\"").replace("''", "\"")
138
+
139
+ # Replace accents.
140
+ line = unicodedata.normalize("NFKD", line)
141
+ line = "".join([c for c in line if not unicodedata.combining(c)])
142
+
143
+ if do_lower_case:
144
+ line = line.lower()
145
+ return line
146
+
147
+
148
+ def preprocess_and_tokenize_input_files(
149
+ input_files: Iterable[str],
150
+ tokenizer: tokenization.FullSentencePieceTokenizer,
151
+ use_eod: bool = True,
152
+ do_lower_case: bool = False,
153
+ log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
154
+ """Preprocesses and encodes raw text from input files.
155
+
156
+ This function preprocesses raw text and encodes it into token IDs using a
157
+ `SentencePieceModel` tokenization method. This also provides the sentence
158
+ indicator for each token.
159
+
160
+ Args:
161
+ input_files: The list of input file names.
162
+ tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
163
+ use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
164
+ not included.
165
+ do_lower_case: Whether or not to apply lower casing during raw text
166
+ preprocessing.
167
+ log_example_freq: The optional field for how many lines to process before
168
+ emitting an info log.
169
+
170
+ Returns:
171
+ The preprocessed list. Each entry in the list is a tuple consisting of
172
+ the token IDs and the sentence IDs.
173
+
174
+ """
175
+ all_data = []
176
+ eod_symbol = special_symbols["<eod>"]
177
+
178
+ total_number_of_lines = 0
179
+
180
+ # Input file format:
181
+ # (1) One sentence per line. These should ideally be actual sentences, not
182
+ # entire paragraphs or arbitrary spans of text. (Because we use the
183
+ # sentence boundaries for the "next sentence prediction" task).
184
+ # (2) Blank lines between documents. Document boundaries are needed so
185
+ # that the "next sentence prediction" task doesn't span between documents.
186
+ for input_file in input_files:
187
+ line_count = 0
188
+ logging.info("Preprocessing %s", input_file)
189
+
190
+ all_tokens = []
191
+ all_sentence_ids = []
192
+
193
+ sentence_id = True
194
+
195
+ with tf.io.gfile.GFile(input_file, "rb") as reader:
196
+ while True:
197
+ line = tokenization.convert_to_unicode(reader.readline())
198
+ if not line:
199
+ break
200
+
201
+ line_count += 1
202
+ if line_count % log_example_freq == 0:
203
+ logging.info("Loading line %d", line_count)
204
+
205
+ line = line.strip()
206
+
207
+ if not line:
208
+ if use_eod:
209
+ token_ids = [eod_symbol]
210
+ sentence_id = not sentence_id
211
+ else:
212
+ continue
213
+ else:
214
+ preprocessed_line = _preprocess_line(
215
+ line=line, do_lower_case=do_lower_case)
216
+ token_ids = tokenization.encode_ids(
217
+ sp_model=tokenizer.sp_model, text=preprocessed_line)
218
+
219
+ all_tokens.extend(token_ids)
220
+ all_sentence_ids.extend([sentence_id] * len(token_ids))
221
+ sentence_id = not sentence_id
222
+ logging.info("Finished processing %s. Number of lines: %d",
223
+ input_file, line_count)
224
+ if line_count == 0:
225
+ continue
226
+ total_number_of_lines += line_count
227
+ all_tokens = np.array(all_tokens, dtype=np.int64)
228
+ all_sentence_ids = np.array(all_sentence_ids, dtype=bool)
229
+ all_data.append((all_tokens, all_sentence_ids))
230
+
231
+ logging.info("Completed text preprocessing. Total number of lines: %d",
232
+ total_number_of_lines)
233
+ return all_data
234
+
235
+
236
+ def _reshape_to_batch_dimensions(
237
+ tokens: np.array,
238
+ sentence_ids: np.array,
239
+ per_host_batch_size: int) -> Tuple[np.array, np.array]:
240
+ """Truncates and reshapes input data with a batch major dimension.
241
+
242
+ Args:
243
+ tokens: The input token ids. This should have the same shape as
244
+ `sentence_ids`.
245
+ sentence_ids: The input sentence ids. This should have the same shape as
246
+ `token_ids`.
247
+ per_host_batch_size: The target per-host batch size.
248
+
249
+ Returns:
250
+ The tuple of reshaped tokens and sentence_ids.
251
+ """
252
+ num_steps = len(tokens) // per_host_batch_size
253
+ truncated_data_length = num_steps * per_host_batch_size
254
+
255
+ logging.info("per_host_batch_size: %d", per_host_batch_size)
256
+ logging.info("num_steps: %d", num_steps)
257
+ def truncate_and_reshape(a):
258
+ return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))
259
+
260
+ return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))
261
+
262
+
263
+ def _create_a_and_b_segments(
264
+ tokens: np.array,
265
+ sentence_ids: np.array,
266
+ begin_index: int,
267
+ total_length: int,
268
+ no_cut_probability: float = 0.5):
269
+ """Splits segments A and B from a single instance of tokens and sentence ids.
270
+
271
+ Args:
272
+ tokens: The 1D input token ids. This represents an individual entry within a
273
+ batch.
274
+ sentence_ids: The 1D input sentence ids. This represents an individual entry
275
+ within a batch. This should be the same length as `tokens`.
276
+ begin_index: The reference beginning index to split data.
277
+ total_length: The target combined length of segments A and B.
278
+ no_cut_probability: The probability of not cutting a segment despite
279
+ a cut possibly existing.
280
+
281
+ Returns:
282
+ A tuple consisting of A data, B data, and label.
283
+
284
+ """
285
+ data_length = tokens.shape[0]
286
+ if begin_index + total_length >= data_length:
287
+ logging.info("[_create_segments]: begin_index %d + total_length %d >= "
288
+ "data_length %d", begin_index, total_length, data_length)
289
+ return None
290
+
291
+ end_index = begin_index + 1
292
+ cut_indices = []
293
+
294
+ # Identify all indices where sentence IDs change from one to the next.
295
+ while end_index < data_length:
296
+ if sentence_ids[end_index] != sentence_ids[end_index - 1]:
297
+ if end_index - begin_index >= total_length:
298
+ break
299
+ cut_indices.append(end_index)
300
+ end_index += 1
301
+
302
+ a_begin = begin_index
303
+
304
+ if not cut_indices or random.random() < no_cut_probability:
305
+ # Segments A and B are contained within the same sentence.
306
+ label = 0
307
+ if not cut_indices:
308
+ a_end = end_index
309
+ else:
310
+ a_end = random.choice(cut_indices)
311
+ b_length = max(1, total_length - (a_end - a_begin))
312
+ b_begin = random.randint(0, data_length - 1 - b_length)
313
+ b_end = b_begin + b_length
314
+
315
+ while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
316
+ b_begin -= 1
317
+ while (b_end < data_length - 1 and
318
+ sentence_ids[b_end - 1] == sentence_ids[b_end]):
319
+ b_end += 1
320
+ else:
321
+ # Segments A and B are different sentences.
322
+ label = 1
323
+ a_end = random.choice(cut_indices)
324
+ b_begin = a_end
325
+ b_end = end_index
326
+
327
+ while a_end - a_begin + b_end - b_begin > total_length:
328
+ if a_end - a_begin > b_end - b_begin:
329
+ # Delete only the right side for the LM objective.
330
+ a_end -= 1
331
+ else:
332
+ b_end -= 1
333
+ if a_end >= data_length or b_end >= data_length:
334
+ logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
335
+ a_end, b_end, data_length)
336
+ return None
337
+
338
+ a_data = tokens[a_begin: a_end]
339
+ b_data = tokens[b_begin: b_end]
340
+ return a_data, b_data, label
341
+
342
+
343
+ def _is_functional_piece(piece: str) -> bool:
344
+ return piece != "<unk>" and piece.startswith("<") and piece.endswith(">")
345
+
346
+
347
+ def _is_start_piece(piece: str) -> bool:
348
+ special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
349
+ if (piece.startswith("▁") or piece in special_pieces):
350
+ return True
351
+ else:
352
+ return False
353
+
354
+
355
+ def _get_boundary_indices(
356
+ data: np.array,
357
+ tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
358
+ """Gets the boundary indices of whole words."""
359
+ seq_length = len(data)
360
+ boundary_indices = []
361
+ for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
362
+ if _is_start_piece(piece) and not _is_functional_piece(piece):
363
+ boundary_indices.append(index)
364
+ boundary_indices.append(seq_length)
365
+ return boundary_indices
366
+
367
+
368
+ def _convert_tokens_to_instances(
369
+ tokens: np.array,
370
+ sentence_ids: np.array,
371
+ per_host_batch_size: int,
372
+ seq_length: int,
373
+ reuse_length: int,
374
+ bi_data: bool,
375
+ tokenizer: tokenization.FullSentencePieceTokenizer,
376
+ num_cores_per_host: int = 0,
377
+ logging_frequency: int = 500) -> List[TrainingInstance]:
378
+ """Converts tokens and sentence IDs into individual training instances.
379
+
380
+ The format of data in the XLNet pretraining task is very similar to the
381
+ BERT pretraining task. Two segments A and B are randomly sampled, and the
382
+ concatenation of A and B into a single sequence is used to perform
383
+ language modeling.
384
+
385
+ To create an XLNet Pretraining instance from a single long sequence, S:
386
+ - Create a segment of length `reuse_length`. This first segment represents
387
+ past tokens. During modeling, this segment is used to cache obtained
388
+ content representations for the segment recurrence mechanism.
389
+ - Similar to BERT, create a segment of length `seq_length` - `reuse_length`
390
+ composed of A and B segments.
391
+ For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".
392
+
393
+ Args:
394
+ tokens: All tokens concatenated into a single list.
395
+ sentence_ids: All sentence IDs concatenated into a single list.
396
+ per_host_batch_size: The target batch size per host.
397
+ seq_length: The max sequence length.
398
+ reuse_length: The number of tokens to use from the previous segment.
399
+ bi_data: Whether or not to use bidirectional data.
400
+ tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
401
+ num_cores_per_host: The number of cores per host. This is required if
402
+ `bi_data` = `True`.
403
+ logging_frequency: The frequency at which to log status updates.
404
+
405
+ Returns:
406
+ A list of `TrainingInstance` objects.
407
+ """
408
+ instances = []
409
+
410
+ per_core_batch_size = (per_host_batch_size // num_cores_per_host
411
+ if bi_data else None)
412
+
413
+ if bi_data:
414
+ logging.info("Bi-directional data enabled.")
415
+ assert per_host_batch_size % (2 * num_cores_per_host) == 0
416
+ forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
417
+ tokens=tokens,
418
+ sentence_ids=sentence_ids,
419
+ per_host_batch_size=per_host_batch_size // 2)
420
+ forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)
421
+
422
+ forward_tokens = forward_tokens.reshape(forward_data_shape)
423
+ forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)
424
+
425
+ backwards_tokens = forward_tokens[:, :, :, ::-1]
426
+ backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]
427
+
428
+ tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
429
+ per_host_batch_size, -1)
430
+ sentence_ids = np.concatenate(
431
+ [forward_sentence_ids, backwards_sentence_ids]).reshape(
432
+ per_host_batch_size, -1)
433
+ else:
434
+ logging.info("Bi-directional data disabled.")
435
+ tokens, sentence_ids = _reshape_to_batch_dimensions(
436
+ tokens=tokens,
437
+ sentence_ids=sentence_ids,
438
+ per_host_batch_size=per_host_batch_size)
439
+
440
+ logging.info("Tokens shape: %s", tokens.shape)
441
+
442
+ data_length = tokens.shape[1]
443
+ sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
444
+ cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
445
+ # 2 sep, 1 cls
446
+ num_special_tokens = 3
447
+
448
+ data_index = 0
449
+ batch_number = 0
450
+ step_size = reuse_length if reuse_length else seq_length
451
+ num_batches = math.ceil(data_length / step_size)
452
+
453
+ while data_index + seq_length <= data_length:
454
+ if batch_number % logging_frequency == 0:
455
+ logging.info("Processing batch %d of %d", batch_number, num_batches)
456
+
457
+ for batch_index in range(per_host_batch_size):
458
+ previous_segment_tokens = tokens[
459
+ batch_index, data_index: data_index + reuse_length]
460
+
461
+ results = _create_a_and_b_segments(
462
+ tokens=tokens[batch_index],
463
+ sentence_ids=sentence_ids[batch_index],
464
+ begin_index=data_index + reuse_length,
465
+ total_length=seq_length - reuse_length - num_special_tokens)
466
+
467
+ if results is None:
468
+ logging.info("Stopping at data index: %d", data_index)
469
+ break
470
+ a_data, b_data, label = results
471
+
472
+ data = np.concatenate(
473
+ [previous_segment_tokens, a_data, sep, b_data, sep, cls])
474
+ a_length = a_data.shape[0]
475
+ b_length = b_data.shape[0]
476
+ segment_ids = ([0] * (reuse_length + a_length) + [0]
477
+ + [1] * b_length + [1] + [2])
478
+ boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
479
+ data=data)
480
+ assert len(data) == seq_length
481
+ assert len(segment_ids) == seq_length
482
+ assert len(boundary_indices) > 0 # pylint: disable=g-explicit-length-test
483
+
484
+ instances.append(TrainingInstance(
485
+ data=data,
486
+ segment_ids=segment_ids,
487
+ boundary_indices=boundary_indices,
488
+ label=label))
489
+ batch_number += 1
490
+ data_index += step_size
491
+ return instances
492
+
493
+
494
+ def write_instances_to_tfrecord(
495
+ instances: Iterable[TrainingInstance],
496
+ save_path: str):
497
+ """Writes instances to TFRecord."""
498
+ record_writer = tf.io.TFRecordWriter(save_path)
499
+ logging.info("Start writing to %s.", save_path)
500
+
501
+ for i, instance in enumerate(instances):
502
+ if i < 5:
503
+ logging.info("Instance %d: %s", i, str(instance))
504
+ record_writer.write(instance.to_example().SerializeToString())
505
+
506
+ record_writer.close()
507
+ logging.info("Done writing %s.", save_path)
508
+
509
+
510
+ def shuffle_and_combine_preprocessed_data(
511
+ all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
512
+ """Shuffles and combines preprocessed token/sentence IDs from documents."""
513
+ document_permutation = np.random.permutation(len(all_data))
514
+
515
+ previous_sentence_id = None
516
+
517
+ all_tokens, all_sentence_ids = [], []
518
+ for document_index in document_permutation:
519
+ tokens, sentence_ids = all_data[document_index]
520
+ # pylint: disable=g-explicit-length-test
521
+ if len(tokens) == 0:
522
+ continue
523
+ if (previous_sentence_id is not None and
524
+ sentence_ids[0] == previous_sentence_id):
525
+ sentence_ids = np.logical_not(sentence_ids)
526
+
527
+ all_tokens.append(tokens)
528
+ all_sentence_ids.append(sentence_ids)
529
+
530
+ previous_sentence_id = sentence_ids[-1]
531
+
532
+ return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)
533
+
534
+
535
+ def get_tfrecord_name(
536
+ per_host_batch_size: int,
537
+ num_cores_per_host: int,
538
+ seq_length: int,
539
+ bi_data: bool,
540
+ reuse_length: int,
541
+ do_lower_case: bool,
542
+ use_eod_token: bool,
543
+ prefix: str = "",
544
+ suffix: str = "",
545
+ pass_id: int = 0,
546
+ num_passes: int = 1,
547
+ task_id: int = None,
548
+ num_tasks: int = None) -> str:
549
+ """Formats the resulting TFRecord name based on provided inputs."""
550
+ components = []
551
+ if prefix:
552
+ components.append(prefix)
553
+ components.append("seqlen-{}".format(seq_length))
554
+ if reuse_length == 0:
555
+ components.append("memless")
556
+ else:
557
+ components.append("reuse-{}".format(reuse_length))
558
+ components.append("bs-{}".format(per_host_batch_size))
559
+ components.append("cores-{}".format(num_cores_per_host))
560
+
561
+ if do_lower_case:
562
+ components.append("uncased")
563
+ else:
564
+ components.append("cased")
565
+ if use_eod_token:
566
+ components.append("eod")
567
+ if bi_data:
568
+ components.append("bi")
569
+ else:
570
+ components.append("uni")
571
+
572
+ if suffix:
573
+ components.append(suffix)
574
+
575
+ s = "_".join(components) + ".tfrecord"
576
+ if num_passes == 1 and task_id is None:
577
+ return s
578
+
579
+ if task_id is None:
580
+ num_tasks = 1
581
+ task_id = 0
582
+
583
+ current_shard = task_id * num_passes + pass_id
584
+ total_shards = num_tasks * num_passes
585
+ return s + "-{}-of-{}".format(current_shard, total_shards)
586
+
587
+
588
+ def create_tfrecords(
589
+ tokenizer: tokenization.FullSentencePieceTokenizer,
590
+ input_file_or_files: str,
591
+ use_eod_token: bool,
592
+ do_lower_case: bool,
593
+ per_host_batch_size: int,
594
+ seq_length: int,
595
+ reuse_length: int,
596
+ bi_data: bool,
597
+ num_cores_per_host: int,
598
+ save_dir: str,
599
+ prefix: str = "",
600
+ suffix: str = "",
601
+ num_tasks: Optional[int] = None,
602
+ task_id: Optional[int] = None,
603
+ num_passes: int = 1):
604
+ """Runs the end-to-end preprocessing pipeline."""
605
+
606
+ logging.info("Input configuration:")
607
+ logging.info("input file(s): %s", input_file_or_files)
608
+ logging.info("use_eod_token: %s", use_eod_token)
609
+ logging.info("do_lower_case: %s", do_lower_case)
610
+ logging.info("per_host_batch_size: %d", per_host_batch_size)
611
+ logging.info("seq_length: %d", seq_length)
612
+ logging.info("reuse_length: %d", reuse_length)
613
+ logging.info("bi_data: %s", bi_data)
614
+ logging.info("num_cores_per_host: %d", num_cores_per_host)
615
+ logging.info("save_dir: %s", save_dir)
616
+ if task_id is not None and num_tasks is not None:
617
+ logging.info("task_id: %d", task_id)
618
+ logging.info("num_tasks: %d", num_tasks)
619
+
620
+ input_files = []
621
+ for input_pattern in input_file_or_files.split(","):
622
+ input_files.extend(tf.io.gfile.glob(input_pattern))
623
+
624
+ logging.info("*** Reading from input files ***")
625
+ for input_file in input_files:
626
+ logging.info(" %s", input_file)
627
+
628
+ logging.info("Shuffling the files with a fixed random seed.")
629
+ np.random.shuffle(input_files)
630
+ if num_tasks is not None:
631
+ assert task_id is not None
632
+ logging.info("Total number of input files: %d", len(input_files))
633
+ logging.info("Splitting into %d shards of %d files each.",
634
+ num_tasks, len(input_files) // num_tasks)
635
+ input_files = input_files[task_id::num_tasks]
636
+
637
+ all_data = preprocess_and_tokenize_input_files(
638
+ input_files=input_files,
639
+ tokenizer=tokenizer,
640
+ use_eod=use_eod_token,
641
+ do_lower_case=do_lower_case)
642
+ for pass_id in range(num_passes):
643
+ logging.info("Beginning pass %d of %d", pass_id, num_passes)
644
+ tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)
645
+
646
+ assert len(tokens) == len(sentence_ids)
647
+
648
+ filename = get_tfrecord_name(
649
+ per_host_batch_size=per_host_batch_size,
650
+ num_cores_per_host=num_cores_per_host,
651
+ seq_length=seq_length,
652
+ bi_data=bi_data,
653
+ use_eod_token=use_eod_token,
654
+ reuse_length=reuse_length,
655
+ do_lower_case=do_lower_case,
656
+ prefix=prefix,
657
+ suffix=suffix,
658
+ pass_id=pass_id,
659
+ num_passes=num_passes,
660
+ num_tasks=num_tasks,
661
+ task_id=task_id)
662
+ save_path = os.path.join(save_dir, filename)
663
+ if os.path.exists(save_path):
664
+ # If the path already exists, then we were probably preempted but
665
+ # previously wrote this file.
666
+ logging.info("%s already exists, skipping this batch.", save_path)
667
+ else:
668
+ instances = _convert_tokens_to_instances(
669
+ tokenizer=tokenizer,
670
+ tokens=tokens,
671
+ sentence_ids=sentence_ids,
672
+ per_host_batch_size=per_host_batch_size,
673
+ seq_length=seq_length,
674
+ reuse_length=reuse_length,
675
+ bi_data=bi_data,
676
+ num_cores_per_host=num_cores_per_host)
677
+ write_instances_to_tfrecord(instances=instances, save_path=save_path)
678
+
679
+ if task_id is None or task_id == 0:
680
+ corpus_info = {
681
+ "vocab_size": 32000,
682
+ "per_host_batch_size": per_host_batch_size,
683
+ "num_cores_per_host": num_cores_per_host,
684
+ "seq_length": seq_length,
685
+ "reuse_length": reuse_length,
686
+ "do_lower_case": do_lower_case,
687
+ "bi_data": bi_data,
688
+ "use_eod_token": use_eod_token,
689
+ }
690
+ corpus_fname = os.path.basename(filename) + ".json"
691
+ corpus_destination = os.path.join(save_dir, corpus_fname)
692
+ logging.info("Saving corpus info to %s", corpus_destination)
693
+
694
+ with tf.io.gfile.GFile(corpus_destination, "w") as fp:
695
+ json.dump(corpus_info, fp)
696
+
697
+
698
+ def main(_):
699
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
700
+ create_tfrecords(
701
+ tokenizer=tokenizer,
702
+ input_file_or_files=FLAGS.input_file,
703
+ use_eod_token=FLAGS.use_eod_token,
704
+ do_lower_case=FLAGS.do_lower_case,
705
+ per_host_batch_size=FLAGS.per_host_batch_size,
706
+ seq_length=FLAGS.seq_length,
707
+ reuse_length=FLAGS.reuse_length,
708
+ bi_data=FLAGS.bi_data,
709
+ num_cores_per_host=FLAGS.num_cores_per_host,
710
+ save_dir=FLAGS.save_dir,
711
+ prefix=FLAGS.prefix,
712
+ suffix=FLAGS.suffix,
713
+ num_tasks=FLAGS.num_tasks,
714
+ task_id=FLAGS.task_id,
715
+ num_passes=FLAGS.num_passes)
716
+
717
+
718
+ if __name__ == "__main__":
719
+ np.random.seed(0)
720
+ logging.set_verbosity(logging.INFO)
721
+ app.run(main)
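
As a usage sketch, the pipeline above can also be driven directly from Python instead of the flag-based entry point, mirroring what `main()` does; the corpus glob, SentencePiece model path, and output directory below are hypothetical placeholders, and the remaining arguments simply echo the flag defaults.

# Sketch only; paths are hypothetical placeholders.
from official.nlp.data import create_xlnet_pretraining_data as cpd
from official.nlp.tools import tokenization

tokenizer = tokenization.FullSentencePieceTokenizer("/tmp/spiece.model")
cpd.create_tfrecords(
    tokenizer=tokenizer,
    input_file_or_files="/tmp/corpus-*.txt",  # comma-separated globs are accepted
    use_eod_token=True,
    do_lower_case=True,
    per_host_batch_size=32,   # must be divisible by 2 * num_cores_per_host when bi_data=True
    seq_length=512,
    reuse_length=256,
    bi_data=True,
    num_cores_per_host=16,
    save_dir="/tmp/xlnet_pretrain")
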
create_xlnet_pretraining_data_test.py ADDED
@@ -0,0 +1,355 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.create_xlnet_pretraining_data."""
16
+ import os
17
+ import tempfile
18
+ from typing import List
19
+
20
+ from absl import logging
21
+ from absl.testing import parameterized
22
+
23
+ import numpy as np
24
+ import tensorflow as tf, tf_keras
25
+
26
+ from official.nlp.data import create_xlnet_pretraining_data as cpd
27
+
28
+ _VOCAB_WORDS = ["vocab_1", "vocab_2"]
29
+
30
+
31
+ # pylint: disable=invalid-name
32
+ def _create_files(
33
+ temp_dir: str, file_contents: List[List[str]]) -> List[str]:
34
+ """Writes arbitrary documents into files."""
35
+ root_dir = tempfile.mkdtemp(dir=temp_dir)
36
+ files = []
37
+
38
+ for i, file_content in enumerate(file_contents):
39
+ destination = os.path.join(root_dir, "%d.txt" % i)
40
+ with open(destination, "wb") as f:
41
+ for line in file_content:
42
+ f.write(line.encode("utf-8"))
43
+ files.append(destination)
44
+ return files
45
+
46
+
47
+ def _get_mock_tokenizer():
48
+ """Creates a mock tokenizer."""
49
+
50
+ class MockSpieceModel:
51
+ """Mock Spiece model for testing."""
52
+
53
+ def __init__(self):
54
+ self._special_piece_to_id = {
55
+ "<unk>": 0,
56
+ }
57
+ for piece in set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')):
58
+ self._special_piece_to_id[piece] = 1
59
+
60
+ def EncodeAsPieces(self, inputs: str) -> List[str]:
61
+ return inputs
62
+
63
+ def SampleEncodeAsPieces(self,
64
+ inputs: str,
65
+ nbest_size: int,
66
+ theta: float) -> List[str]:
67
+ del nbest_size, theta
68
+ return inputs
69
+
70
+ def PieceToId(self, piece: str) -> int:
71
+ return ord(piece[0])
72
+
73
+ def IdToPiece(self, id_: int) -> str:
74
+ return chr(id_) * 3
75
+
76
+ class Tokenizer:
77
+ """Mock Tokenizer for testing."""
78
+
79
+ def __init__(self):
80
+ self.sp_model = MockSpieceModel()
81
+
82
+ def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
83
+ return [self.sp_model.IdToPiece(id_) for id_ in ids]
84
+
85
+ return Tokenizer()
86
+
87
+
88
+ class PreprocessDataTest(tf.test.TestCase):
89
+
90
+ def test_remove_extraneous_space(self):
91
+ line = " abc "
92
+ output = cpd._preprocess_line(line)
93
+ self.assertEqual(output, "abc")
94
+
95
+ def test_symbol_replacements(self):
96
+ self.assertEqual(cpd._preprocess_line("``abc``"), "\"abc\"")
97
+ self.assertEqual(cpd._preprocess_line("''abc''"), "\"abc\"")
98
+
99
+ def test_accent_replacements(self):
100
+ self.assertEqual(cpd._preprocess_line("åbc"), "abc")
101
+
102
+ def test_lower_case(self):
103
+ self.assertEqual(cpd._preprocess_line("ABC", do_lower_case=True), "abc")
104
+
105
+ def test_end_to_end(self):
106
+ self.assertEqual(
107
+ cpd._preprocess_line("HelLo ``wΓ³rLd``", do_lower_case=True),
108
+ "hello \"world\"")
109
+
110
+
111
+ class PreprocessAndTokenizeFilesTest(tf.test.TestCase):
112
+
113
+ def test_basic_end_to_end(self):
114
+ documents = [
115
+ [
116
+ "This is sentence 1.\n",
117
+ "This is sentence 2.\n",
118
+ "Sentence 3 is what this is.\n",
119
+ ],
120
+ [
121
+ "This is the second document.\n",
122
+ "This is the second line of the second document.\n"
123
+ ],
124
+ ]
125
+ input_files = _create_files(temp_dir=self.get_temp_dir(),
126
+ file_contents=documents)
127
+ all_data = cpd.preprocess_and_tokenize_input_files(
128
+ input_files=input_files,
129
+ tokenizer=_get_mock_tokenizer(),
130
+ log_example_freq=1)
131
+
132
+ self.assertEqual(len(all_data), len(documents))
133
+ for token_ids, sentence_ids in all_data:
134
+ self.assertEqual(len(token_ids), len(sentence_ids))
135
+
136
+ def test_basic_correctness(self):
137
+ documents = [["a\n", "b\n", "c\n"]]
138
+ input_files = _create_files(temp_dir=self.get_temp_dir(),
139
+ file_contents=documents)
140
+ all_data = cpd.preprocess_and_tokenize_input_files(
141
+ input_files=input_files,
142
+ tokenizer=_get_mock_tokenizer(),
143
+ log_example_freq=1)
144
+
145
+ token_ids, sentence_ids = all_data[0]
146
+
147
+ self.assertAllClose(token_ids, [97, 98, 99])
148
+ self.assertAllClose(sentence_ids, [True, False, True])
149
+
150
+ def test_correctness_with_spaces_and_accents(self):
151
+ documents = [[
152
+ " Γ₯ \n",
153
+ "b \n",
154
+ " c \n",
155
+ ]]
156
+ input_files = _create_files(temp_dir=self.get_temp_dir(),
157
+ file_contents=documents)
158
+ all_data = cpd.preprocess_and_tokenize_input_files(
159
+ input_files=input_files,
160
+ tokenizer=_get_mock_tokenizer(),
161
+ log_example_freq=1)
162
+
163
+ token_ids, sentence_ids = all_data[0]
164
+
165
+ self.assertAllClose(token_ids, [97, 98, 99])
166
+ self.assertAllClose(sentence_ids, [True, False, True])
167
+
168
+
169
+ class BatchReshapeTests(tf.test.TestCase):
170
+
171
+ def test_basic_functionality(self):
172
+ per_host_batch_size = 3
173
+ mock_shape = (20,)
174
+
175
+ # Should truncate and reshape.
176
+ expected_result_shape = (3, 6)
177
+
178
+ tokens = np.zeros(mock_shape)
179
+ sentence_ids = np.zeros(mock_shape)
180
+
181
+ reshaped_data = cpd._reshape_to_batch_dimensions(
182
+ tokens=tokens,
183
+ sentence_ids=sentence_ids,
184
+ per_host_batch_size=per_host_batch_size)
185
+ for values in reshaped_data:
186
+ self.assertEqual(len(values.flatten()) % per_host_batch_size, 0)
187
+ self.assertAllClose(values.shape, expected_result_shape)
188
+
189
+
190
+ class CreateSegmentsTest(tf.test.TestCase):
191
+
192
+ def test_basic_functionality(self):
193
+ data_length = 10
194
+ tokens = np.arange(data_length)
195
+ sentence_ids = np.concatenate([np.zeros(data_length // 2),
196
+ np.ones(data_length // 2)])
197
+ begin_index = 0
198
+ total_length = 8
199
+ a_data, b_data, label = cpd._create_a_and_b_segments(
200
+ tokens=tokens,
201
+ sentence_ids=sentence_ids,
202
+ begin_index=begin_index,
203
+ total_length=total_length,
204
+ no_cut_probability=0.)
205
+ self.assertAllClose(a_data, [0, 1, 2, 3])
206
+ self.assertAllClose(b_data, [5, 6, 7, 8])
207
+ self.assertEqual(label, 1)
208
+
209
+ def test_no_cut(self):
210
+ data_length = 10
211
+ tokens = np.arange(data_length)
212
+ sentence_ids = np.zeros(data_length)
213
+
214
+ begin_index = 0
215
+ total_length = 8
216
+ a_data, b_data, label = cpd._create_a_and_b_segments(
217
+ tokens=tokens,
218
+ sentence_ids=sentence_ids,
219
+ begin_index=begin_index,
220
+ total_length=total_length,
221
+ no_cut_probability=0.)
222
+ self.assertGreater(len(a_data), 0)
223
+ self.assertGreater(len(b_data), 0)
224
+ self.assertEqual(label, 0)
225
+
226
+ def test_no_cut_with_probability(self):
227
+ data_length = 10
228
+ tokens = np.arange(data_length)
229
+ sentence_ids = np.concatenate([np.zeros(data_length // 2),
230
+ np.ones(data_length // 2)])
231
+ begin_index = 0
232
+ total_length = 8
233
+ a_data, b_data, label = cpd._create_a_and_b_segments(
234
+ tokens=tokens,
235
+ sentence_ids=sentence_ids,
236
+ begin_index=begin_index,
237
+ total_length=total_length,
238
+ no_cut_probability=1.)
239
+ self.assertGreater(len(a_data), 0)
240
+ self.assertGreater(len(b_data), 0)
241
+ self.assertEqual(label, 0)
242
+
243
+
244
+ class CreateInstancesTest(tf.test.TestCase):
245
+ """Tests conversions of Token/Sentence IDs to training instances."""
246
+
247
+ def test_basic(self):
248
+ data_length = 12
249
+ tokens = np.arange(data_length)
250
+ sentence_ids = np.zeros(data_length)
251
+ seq_length = 8
252
+ instances = cpd._convert_tokens_to_instances(
253
+ tokens=tokens,
254
+ sentence_ids=sentence_ids,
255
+ per_host_batch_size=2,
256
+ seq_length=seq_length,
257
+ reuse_length=4,
258
+ tokenizer=_get_mock_tokenizer(),
259
+ bi_data=False,
260
+ num_cores_per_host=1,
261
+ logging_frequency=1)
262
+ for instance in instances:
263
+ self.assertEqual(len(instance.data), seq_length)
264
+ self.assertEqual(len(instance.segment_ids), seq_length)
265
+ self.assertIsInstance(instance.label, int)
266
+ self.assertIsInstance(instance.boundary_indices, list)
267
+
268
+
269
+ class TFRecordPathTests(tf.test.TestCase):
270
+
271
+ def test_basic(self):
272
+ base_kwargs = dict(
273
+ per_host_batch_size=1,
274
+ num_cores_per_host=1,
275
+ seq_length=2,
276
+ reuse_length=1)
277
+
278
+ config1 = dict(
279
+ prefix="test",
280
+ suffix="",
281
+ bi_data=True,
282
+ use_eod_token=False,
283
+ do_lower_case=True)
284
+ config1.update(base_kwargs)
285
+ expectation1 = "test_seqlen-2_reuse-1_bs-1_cores-1_uncased_bi.tfrecord"
286
+ self.assertEqual(cpd.get_tfrecord_name(**config1), expectation1)
287
+
288
+ config2 = dict(
289
+ prefix="",
290
+ suffix="test",
291
+ bi_data=False,
292
+ use_eod_token=False,
293
+ do_lower_case=False)
294
+ config2.update(base_kwargs)
295
+ expectation2 = "seqlen-2_reuse-1_bs-1_cores-1_cased_uni_test.tfrecord"
296
+ self.assertEqual(cpd.get_tfrecord_name(**config2), expectation2)
297
+
298
+ config3 = dict(
299
+ prefix="",
300
+ suffix="",
301
+ use_eod_token=True,
302
+ bi_data=False,
303
+ do_lower_case=True)
304
+ config3.update(base_kwargs)
305
+ expectation3 = "seqlen-2_reuse-1_bs-1_cores-1_uncased_eod_uni.tfrecord"
306
+ self.assertEqual(cpd.get_tfrecord_name(**config3), expectation3)
307
+
308
+
309
+ class TestCreateTFRecords(parameterized.TestCase, tf.test.TestCase):
310
+
311
+ @parameterized.named_parameters(
312
+ ("bi_data_only", True, False, False),
313
+ ("eod_token_only", False, True, True),
314
+ ("lower_case_only", False, False, True),
315
+ ("all_enabled", True, True, True),
316
+ )
317
+ def test_end_to_end(self,
318
+ bi_data: bool,
319
+ use_eod_token: bool,
320
+ do_lower_case: bool):
321
+ tokenizer = _get_mock_tokenizer()
322
+
323
+ num_documents = 5
324
+ sentences_per_document = 10
325
+ document_length = 50
326
+
327
+ documents = [
328
+ ["a " * document_length for _ in range(sentences_per_document)]
329
+ for _ in range(num_documents)]
330
+
331
+ save_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
332
+ files = _create_files(temp_dir=self.get_temp_dir(), file_contents=documents)
333
+
334
+ cpd.create_tfrecords(
335
+ tokenizer=tokenizer,
336
+ input_file_or_files=",".join(files),
337
+ use_eod_token=use_eod_token,
338
+ do_lower_case=do_lower_case,
339
+ per_host_batch_size=8,
340
+ seq_length=8,
341
+ reuse_length=4,
342
+ bi_data=bi_data,
343
+ num_cores_per_host=2,
344
+ save_dir=save_dir)
345
+
346
+ self.assertTrue(any(filter(lambda x: x.endswith(".json"),
347
+ os.listdir(save_dir))))
348
+ self.assertTrue(any(filter(lambda x: x.endswith(".tfrecord"),
349
+ os.listdir(save_dir))))
350
+
351
+
352
+ if __name__ == "__main__":
353
+ np.random.seed(0)
354
+ logging.set_verbosity(logging.INFO)
355
+ tf.test.main()
data_loader.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """An abstraction that NLP models define input pipelines."""
16
+
17
+ import abc
18
+ from typing import Optional
19
+
20
+ import tensorflow as tf, tf_keras
21
+
22
+
23
+ class DataLoader(metaclass=abc.ABCMeta):
24
+ """An abstract class defining the APIs for tf.data input pipeline."""
25
+
26
+ @abc.abstractmethod
27
+ def load(
28
+ self,
29
+ input_context: Optional[tf.distribute.InputContext] = None
30
+ ) -> tf.data.Dataset:
31
+ """Implements DataLoader load method.
32
+
33
+ Builds the entire input pipeline inside the load method. Users can define
34
+ states inside the DataLoader class and returns a tf.data dataset
35
+ object.
36
+
37
+ Args:
38
+ input_context: This is a context class that is passed to the user's input
39
+ function and contains information about the compute replicas and input
40
+ pipelines. This object is used for multi-host inputs and passed by the
41
+ distribution strategy.
42
+
43
+ Returns:
44
+ A per-host tf.data dataset. Note that, we usually create the distributed
45
+ dataset through the load method, so we should not directly return a
46
+ distributed dataset here.
47
+ """
48
+ pass
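
For reference, here is a minimal sketch of what a concrete subclass might look like; `MyTfRecordLoader`, its constructor arguments, and the file pattern are hypothetical, and real loaders in this directory are typically built on `input_reader.InputReader` rather than raw `tf.data` calls.

# Sketch only, under the assumptions stated above.
from typing import Optional

import tensorflow as tf

from official.nlp.data import data_loader


class MyTfRecordLoader(data_loader.DataLoader):
  """Hypothetical loader that batches raw TFRecord files."""

  def __init__(self, file_pattern: str, global_batch_size: int):
    self._file_pattern = file_pattern
    self._global_batch_size = global_batch_size

  def load(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    dataset = tf.data.Dataset.list_files(self._file_pattern, shuffle=True)
    if input_context:
      # Shard by input pipeline so each host reads a disjoint slice of files,
      # and derive the per-replica batch size from the global one.
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      batch_size = input_context.get_per_replica_batch_size(
          self._global_batch_size)
    else:
      batch_size = self._global_batch_size
    dataset = dataset.interleave(
        tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
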
data_loader_factory.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A global factory to access NLP registered data loaders."""
16
+
17
+ from official.core import registry
18
+
19
+ _REGISTERED_DATA_LOADER_CLS = {}
20
+
21
+
22
+ def register_data_loader_cls(data_config_cls):
23
+ """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.
24
+
25
+ This decorator supports registration of data loaders as follows:
26
+
27
+ ```
28
+ @dataclasses.dataclass
29
+ class MyDataConfig(DataConfig):
30
+ # Add fields here.
31
+ pass
32
+
33
+ @register_data_loader_cls(MyDataConfig)
34
+ class MyDataLoader:
35
+ # Inherits def __init__(self, data_config).
36
+ pass
37
+
38
+ my_data_config = MyDataConfig()
39
+
40
+ # Returns MyDataLoader(my_data_config).
41
+ my_loader = get_data_loader(my_data_config)
42
+ ```
43
+
44
+ Args:
45
+ data_config_cls: a subclass of DataConfig (*not* an instance
46
+ of DataConfig).
47
+
48
+ Returns:
49
+ A callable for use as class decorator that registers the decorated class
50
+ for creation from an instance of data_config_cls.
51
+ """
52
+ return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
53
+
54
+
55
+ def get_data_loader(data_config):
56
+ """Creates a data_loader from data_config."""
57
+ return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
58
+ data_config)
data_loader_factory_test.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.data_loader_factory."""
16
+
17
+ import dataclasses
18
+ import tensorflow as tf, tf_keras
19
+
20
+ from official.core import config_definitions as cfg
21
+ from official.nlp.data import data_loader_factory
22
+
23
+
24
+ @dataclasses.dataclass
25
+ class MyDataConfig(cfg.DataConfig):
26
+ is_training: bool = True
27
+
28
+
29
+ @data_loader_factory.register_data_loader_cls(MyDataConfig)
30
+ class MyDataLoader:
31
+
32
+ def __init__(self, params):
33
+ self.params = params
34
+
35
+
36
+ class DataLoaderFactoryTest(tf.test.TestCase):
37
+
38
+ def test_register_and_load(self):
39
+ train_config = MyDataConfig()
40
+ train_loader = data_loader_factory.get_data_loader(train_config)
41
+ self.assertTrue(train_loader.params.is_training)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ tf.test.main()
dual_encoder_dataloader.py ADDED
@@ -0,0 +1,147 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Loads dataset for the dual encoder (retrieval) task."""
16
+ import dataclasses
17
+ import functools
18
+ import itertools
19
+ from typing import Iterable, Mapping, Optional, Tuple
20
+
21
+ import tensorflow as tf, tf_keras
22
+ import tensorflow_hub as hub
23
+
24
+ from official.common import dataset_fn
25
+ from official.core import config_definitions as cfg
26
+ from official.core import input_reader
27
+ from official.nlp.data import data_loader
28
+ from official.nlp.data import data_loader_factory
29
+ from official.nlp.modeling import layers
30
+
31
+
32
+ @dataclasses.dataclass
33
+ class DualEncoderDataConfig(cfg.DataConfig):
34
+ """Data config for dual encoder task (tasks/dual_encoder)."""
35
+ # Either set `input_path`...
36
+ input_path: str = ''
37
+ # ...or `tfds_name` and `tfds_split` to specify input.
38
+ tfds_name: str = ''
39
+ tfds_split: str = ''
40
+ global_batch_size: int = 32
41
+ # Either build preprocessing with Python code by specifying these values...
42
+ vocab_file: str = ''
43
+ lower_case: bool = True
44
+ # ...or load preprocessing from a SavedModel at this location.
45
+ preprocessing_hub_module_url: str = ''
46
+
47
+ left_text_fields: Tuple[str] = ('left_input',)
48
+ right_text_fields: Tuple[str] = ('right_input',)
49
+ is_training: bool = True
50
+ seq_length: int = 128
51
+ file_type: str = 'tfrecord'
52
+
53
+
54
+ @data_loader_factory.register_data_loader_cls(DualEncoderDataConfig)
55
+ class DualEncoderDataLoader(data_loader.DataLoader):
56
+ """A class to load dataset for dual encoder task (tasks/dual_encoder)."""
57
+
58
+ def __init__(self, params):
59
+ if bool(params.tfds_name) == bool(params.input_path):
60
+ raise ValueError('Must specify either `tfds_name` and `tfds_split` '
61
+ 'or `input_path`.')
62
+ if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
63
+ raise ValueError('Must specify exactly one of vocab_file (with matching '
64
+ 'lower_case flag) or preprocessing_hub_module_url.')
65
+ self._params = params
66
+ self._seq_length = params.seq_length
67
+ self._left_text_fields = params.left_text_fields
68
+ self._right_text_fields = params.right_text_fields
69
+
70
+ if params.preprocessing_hub_module_url:
71
+ preprocessing_hub_module = hub.load(params.preprocessing_hub_module_url)
72
+ self._tokenizer = preprocessing_hub_module.tokenize
73
+ self._pack_inputs = functools.partial(
74
+ preprocessing_hub_module.bert_pack_inputs,
75
+ seq_length=params.seq_length)
76
+ else:
77
+ self._tokenizer = layers.BertTokenizer(
78
+ vocab_file=params.vocab_file, lower_case=params.lower_case)
79
+ self._pack_inputs = layers.BertPackInputs(
80
+ seq_length=params.seq_length,
81
+ special_tokens_dict=self._tokenizer.get_special_tokens_dict())
82
+
83
+ def _decode(self, record: tf.Tensor):
84
+ """Decodes a serialized tf.Example."""
85
+ name_to_features = {
86
+ x: tf.io.FixedLenFeature([], tf.string)
87
+ for x in itertools.chain(
88
+ *[self._left_text_fields, self._right_text_fields])
89
+ }
90
+ example = tf.io.parse_single_example(record, name_to_features)
91
+
92
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
93
+ # So cast all int64 to int32.
94
+ for name in example:
95
+ t = example[name]
96
+ if t.dtype == tf.int64:
97
+ t = tf.cast(t, tf.int32)
98
+ example[name] = t
99
+
100
+ return example
101
+
102
+ def _bert_tokenize(
103
+ self, record: Mapping[str, tf.Tensor],
104
+ text_fields: Iterable[str]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
105
+ """Tokenize the input in text_fields using BERT tokenizer.
106
+
107
+ Args:
108
+ record: A tf.Example record containing the features.
109
+ text_fields: A list of fields to be tokenized.
110
+
111
+ Returns:
112
+ The tokenized features in a tuple of (input_word_ids, input_mask,
113
+ input_type_ids).
114
+ """
115
+ segments_text = [record[x] for x in text_fields]
116
+ segments_tokens = [self._tokenizer(s) for s in segments_text]
117
+ segments = [tf.cast(x.merge_dims(1, 2), tf.int32) for x in segments_tokens]
118
+ return self._pack_inputs(segments)
119
+
120
+ def _bert_preprocess(
121
+ self, record: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
122
+ """Perform the bert word piece tokenization for left and right inputs."""
123
+
124
+ def _switch_prefix(string, old, new):
125
+ if string.startswith(old): return new + string[len(old):]
126
+ raise ValueError('Expected {} to start with {}'.format(string, old))
127
+
128
+ def _switch_key_prefix(d, old, new):
129
+ return {_switch_prefix(key, old, new): value for key, value in d.items()} # pytype: disable=attribute-error # trace-all-classes
130
+
131
+ model_inputs = _switch_key_prefix(
132
+ self._bert_tokenize(record, self._left_text_fields),
133
+ 'input_', 'left_')
134
+ model_inputs.update(_switch_key_prefix(
135
+ self._bert_tokenize(record, self._right_text_fields),
136
+ 'input_', 'right_'))
137
+ return model_inputs
138
+
139
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
140
+ """Returns a tf.dataset.Dataset."""
141
+ reader = input_reader.InputReader(
142
+ params=self._params,
143
+ # Skip `decoder_fn` for tfds input.
144
+ decoder_fn=self._decode if self._params.input_path else None,
145
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
146
+ postprocess_fn=self._bert_preprocess)
147
+ return reader.read(input_context)
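
To make the two mutually exclusive preprocessing paths described in the config comments concrete, a minimal sketch of building both kinds of `DualEncoderDataConfig` follows; the TFRecord path and vocab path are hypothetical placeholders, while the TFDS name and TF Hub preprocessing handle match the ones exercised by the test file below.

# Sketch only: illustrates the two ways to configure preprocessing.
from official.nlp.data import dual_encoder_dataloader

# Option 1: tokenize in Python from a vocab file (hypothetical paths).
config_from_vocab = dual_encoder_dataloader.DualEncoderDataConfig(
    input_path="/tmp/dual_encoder_train.tfrecord",
    vocab_file="/tmp/vocab.txt",
    lower_case=True,
    seq_length=64,
    global_batch_size=32)

# Option 2: load preprocessing from a TF Hub SavedModel and read from TFDS.
config_from_hub = dual_encoder_dataloader.DualEncoderDataConfig(
    tfds_name="para_crawl/enmt",
    tfds_split="train",
    left_text_fields=("en",),
    right_text_fields=("mt",),
    preprocessing_hub_module_url=(
        "https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3"),
    seq_length=64,
    global_batch_size=32)

dataset = dual_encoder_dataloader.DualEncoderDataLoader(
    config_from_vocab).load()
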
dual_encoder_dataloader_test.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.dual_encoder_dataloader."""
16
+ import os
17
+
18
+ from absl.testing import parameterized
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.nlp.data import dual_encoder_dataloader
22
+
23
+
24
+ _LEFT_FEATURE_NAME = 'left_input'
25
+ _RIGHT_FEATURE_NAME = 'right_input'
26
+
27
+
28
+ def _create_fake_dataset(output_path):
29
+ """Creates a fake dataset contains examples for training a dual encoder model.
30
+
31
+ The created dataset contains examples with two bytes_list features keyed by
32
+ _LEFT_FEATURE_NAME and _RIGHT_FEATURE_NAME.
33
+
34
+ Args:
35
+ output_path: The output path of the fake dataset.
36
+ """
37
+ def create_str_feature(values):
38
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
39
+
40
+ with tf.io.TFRecordWriter(output_path) as writer:
41
+ for _ in range(100):
42
+ features = {}
43
+ features[_LEFT_FEATURE_NAME] = create_str_feature([b'hello world.'])
44
+ features[_RIGHT_FEATURE_NAME] = create_str_feature([b'world hello.'])
45
+
46
+ tf_example = tf.train.Example(
47
+ features=tf.train.Features(feature=features))
48
+ writer.write(tf_example.SerializeToString())
49
+
50
+
51
+ def _make_vocab_file(vocab, output_path):
52
+ with tf.io.gfile.GFile(output_path, 'w') as f:
53
+ f.write('\n'.join(vocab + ['']))
54
+
55
+
56
+ class DualEncoderDataTest(tf.test.TestCase, parameterized.TestCase):
57
+
58
+ def test_load_dataset(self):
59
+ seq_length = 16
60
+ batch_size = 10
61
+ train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
62
+ vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
63
+
64
+ _create_fake_dataset(train_data_path)
65
+ _make_vocab_file(
66
+ ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'], vocab_path)
67
+
68
+ data_config = dual_encoder_dataloader.DualEncoderDataConfig(
69
+ input_path=train_data_path,
70
+ seq_length=seq_length,
71
+ vocab_file=vocab_path,
72
+ lower_case=True,
73
+ left_text_fields=(_LEFT_FEATURE_NAME,),
74
+ right_text_fields=(_RIGHT_FEATURE_NAME,),
75
+ global_batch_size=batch_size)
76
+ dataset = dual_encoder_dataloader.DualEncoderDataLoader(
77
+ data_config).load()
78
+ features = next(iter(dataset))
79
+ self.assertCountEqual(
80
+ ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
81
+ 'right_mask', 'right_type_ids'],
82
+ features.keys())
83
+ self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
84
+ self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
85
+ self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
86
+ self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
87
+ self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
88
+ self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
89
+
90
+ @parameterized.parameters(False, True)
91
+ def test_load_tfds(self, use_preprocessing_hub):
92
+ seq_length = 16
93
+ batch_size = 10
94
+ if use_preprocessing_hub:
95
+ vocab_path = ''
96
+ preprocessing_hub = (
97
+ 'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3')
98
+ else:
99
+ vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
100
+ _make_vocab_file(
101
+ ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
102
+ vocab_path)
103
+ preprocessing_hub = ''
104
+
105
+ data_config = dual_encoder_dataloader.DualEncoderDataConfig(
106
+ tfds_name='para_crawl/enmt',
107
+ tfds_split='train',
108
+ seq_length=seq_length,
109
+ vocab_file=vocab_path,
110
+ lower_case=True,
111
+ left_text_fields=('en',),
112
+ right_text_fields=('mt',),
113
+ preprocessing_hub_module_url=preprocessing_hub,
114
+ global_batch_size=batch_size)
115
+ dataset = dual_encoder_dataloader.DualEncoderDataLoader(
116
+ data_config).load()
117
+ features = next(iter(dataset))
118
+ self.assertCountEqual(
119
+ ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
120
+ 'right_mask', 'right_type_ids'],
121
+ features.keys())
122
+ self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
123
+ self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
124
+ self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
125
+ self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
126
+ self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
127
+ self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
128
+
129
+
130
+ if __name__ == '__main__':
131
+ tf.test.main()
pretrain_dataloader.py ADDED
@@ -0,0 +1,589 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Loads dataset for the BERT pretraining task."""
16
+ import dataclasses
17
+ from typing import Mapping, Optional
18
+
19
+ from absl import logging
20
+
21
+ import numpy as np
22
+ import tensorflow as tf, tf_keras
23
+ from official.common import dataset_fn
24
+ from official.core import config_definitions as cfg
25
+ from official.core import input_reader
26
+ from official.nlp.data import data_loader
27
+ from official.nlp.data import data_loader_factory
28
+
29
+
30
+ @dataclasses.dataclass
31
+ class BertPretrainDataConfig(cfg.DataConfig):
32
+ """Data config for BERT pretraining task (tasks/masked_lm)."""
33
+ input_path: str = ''
34
+ global_batch_size: int = 512
35
+ is_training: bool = True
36
+ seq_length: int = 512
37
+ max_predictions_per_seq: int = 76
38
+ use_next_sentence_label: bool = True
39
+ use_position_id: bool = False
40
+ # Historically, BERT implementations take `input_ids` and `segment_ids` as
41
+ # feature names. Inside the TF Model Garden implementation, the Keras model
42
+ # inputs are set as `input_word_ids` and `input_type_ids`. When
43
+ # v2_feature_names is True, the data loader assumes the tf.Examples use
44
+ # `input_word_ids` and `input_type_ids` as keys.
45
+ use_v2_feature_names: bool = False
46
+ file_type: str = 'tfrecord'
47
+
48
+
49
+ @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
50
+ class BertPretrainDataLoader(data_loader.DataLoader):
51
+ """A class to load dataset for bert pretraining task."""
52
+
53
+ def __init__(self, params):
54
+ """Inits `BertPretrainDataLoader` class.
55
+
56
+ Args:
57
+ params: A `BertPretrainDataConfig` object.
58
+ """
59
+ self._params = params
60
+ self._seq_length = params.seq_length
61
+ self._max_predictions_per_seq = params.max_predictions_per_seq
62
+ self._use_next_sentence_label = params.use_next_sentence_label
63
+ self._use_position_id = params.use_position_id
64
+
65
+ def _name_to_features(self):
66
+ name_to_features = {
67
+ 'input_mask':
68
+ tf.io.FixedLenFeature([self._seq_length], tf.int64),
69
+ 'masked_lm_positions':
70
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
71
+ 'masked_lm_ids':
72
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
73
+ 'masked_lm_weights':
74
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
75
+ }
76
+ if self._params.use_v2_feature_names:
77
+ name_to_features.update({
78
+ 'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
79
+ 'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
80
+ })
81
+ else:
82
+ name_to_features.update({
83
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
84
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
85
+ })
86
+ if self._use_next_sentence_label:
87
+ name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
88
+ tf.int64)
89
+ if self._use_position_id:
90
+ name_to_features['position_ids'] = tf.io.FixedLenFeature(
91
+ [self._seq_length], tf.int64)
92
+ return name_to_features
93
+
94
+ def _decode(self, record: tf.Tensor):
95
+ """Decodes a serialized tf.Example."""
96
+ name_to_features = self._name_to_features()
97
+ example = tf.io.parse_single_example(record, name_to_features)
98
+
99
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
100
+ # So cast all int64 to int32.
101
+ for name in list(example.keys()):
102
+ t = example[name]
103
+ if t.dtype == tf.int64:
104
+ t = tf.cast(t, tf.int32)
105
+ example[name] = t
106
+
107
+ return example
108
+
109
+ def _parse(self, record: Mapping[str, tf.Tensor]):
110
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
111
+ x = {
112
+ 'input_mask': record['input_mask'],
113
+ 'masked_lm_positions': record['masked_lm_positions'],
114
+ 'masked_lm_ids': record['masked_lm_ids'],
115
+ 'masked_lm_weights': record['masked_lm_weights'],
116
+ }
117
+ if self._params.use_v2_feature_names:
118
+ x['input_word_ids'] = record['input_word_ids']
119
+ x['input_type_ids'] = record['input_type_ids']
120
+ else:
121
+ x['input_word_ids'] = record['input_ids']
122
+ x['input_type_ids'] = record['segment_ids']
123
+ if self._use_next_sentence_label:
124
+ x['next_sentence_labels'] = record['next_sentence_labels']
125
+ if self._use_position_id:
126
+ x['position_ids'] = record['position_ids']
127
+
128
+ return x
129
+
130
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
131
+ """Returns a tf.dataset.Dataset."""
132
+ reader = input_reader.InputReader(
133
+ params=self._params,
134
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
135
+ decoder_fn=self._decode,
136
+ parser_fn=self._parse)
137
+ return reader.read(input_context)
138
+
139
+
140
+ @dataclasses.dataclass
141
+ class XLNetPretrainDataConfig(cfg.DataConfig):
142
+ """Data config for XLNet pretraining task.
143
+
144
+ Attributes:
145
+ input_path: See base class.
146
+ global_batch_size: See base class.
147
+ is_training: See base class.
148
+ seq_length: The length of each sequence.
149
+ max_predictions_per_seq: The number of predictions per sequence.
150
+ reuse_length: The number of tokens in a previous segment to reuse. This
151
+ should be the same value used during pretrain data creation.
152
+ sample_strategy: The strategy used to sample factorization permutations.
153
+ Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
154
+ min_num_tokens: The minimum number of tokens to sample in a span. This is
155
+ used when `sample_strategy` is 'token_span'.
156
+ max_num_tokens: The maximum number of tokens to sample in a span. This is
157
+ used when `sample_strategy` is 'token_span'.
158
+ min_num_words: The minimum number of words to sample in a span. This is used
159
+ when `sample_strategy` is 'word_span'.
160
+ max_num_words: The maximum number of words to sample in a span. This is used
161
+ when `sample_strategy` is 'word_span'.
162
+ permutation_size: The length of the longest permutation. This can be set to
163
+ `reuse_length`. This should NOT be greater than `reuse_length`, otherwise
164
+ this may introduce data leaks.
165
+ leak_ratio: The percentage of masked tokens that are leaked.
166
+ segment_sep_id: The ID of the SEP token used when preprocessing the dataset.
167
+ segment_cls_id: The ID of the CLS token used when preprocessing the dataset.
168
+ """
169
+ input_path: str = ''
170
+ global_batch_size: int = 512
171
+ is_training: bool = True
172
+ seq_length: int = 512
173
+ max_predictions_per_seq: int = 76
174
+ reuse_length: int = 256
175
+ sample_strategy: str = 'word_span'
176
+ min_num_tokens: int = 1
177
+ max_num_tokens: int = 5
178
+ min_num_words: int = 1
179
+ max_num_words: int = 5
180
+ permutation_size: int = 256
181
+ leak_ratio: float = 0.1
182
+ segment_sep_id: int = 4
183
+ segment_cls_id: int = 3
184
+
185
+
186
+ @data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
187
+ class XLNetPretrainDataLoader(data_loader.DataLoader):
188
+ """A class to load dataset for xlnet pretraining task."""
189
+
190
+ def __init__(self, params: XLNetPretrainDataConfig):
191
+ """Inits `XLNetPretrainDataLoader` class.
192
+
193
+ Args:
194
+ params: A `XLNetPretrainDataConfig` object.
195
+ """
196
+ self._params = params
197
+ self._seq_length = params.seq_length
198
+ self._max_predictions_per_seq = params.max_predictions_per_seq
199
+ self._reuse_length = params.reuse_length
200
+ self._num_replicas_in_sync = None
201
+ self._permutation_size = params.permutation_size
202
+ self._sep_id = params.segment_sep_id
203
+ self._cls_id = params.segment_cls_id
204
+ self._sample_strategy = params.sample_strategy
205
+ self._leak_ratio = params.leak_ratio
206
+
207
+ def _decode(self, record: tf.Tensor):
208
+ """Decodes a serialized tf.Example."""
209
+ name_to_features = {
210
+ 'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
211
+ 'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
212
+ 'boundary_indices': tf.io.VarLenFeature(tf.int64),
213
+ }
214
+ example = tf.io.parse_single_example(record, name_to_features)
215
+
216
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
217
+ # So cast all int64 to int32.
218
+ for name in list(example.keys()):
219
+ t = example[name]
220
+ if t.dtype == tf.int64:
221
+ t = tf.cast(t, tf.int32)
222
+ example[name] = t
223
+
224
+ return example
225
+
226
+ def _parse(self, record: Mapping[str, tf.Tensor]):
227
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
228
+ x = {}
229
+
230
+ inputs = record['input_word_ids']
231
+ x['input_type_ids'] = record['input_type_ids']
232
+
233
+ if self._sample_strategy in ['whole_word', 'word_span']:
234
+ boundary = tf.sparse.to_dense(record['boundary_indices'])
235
+ else:
236
+ boundary = None
237
+
238
+ input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)
239
+
240
+ if self._reuse_length > 0:
241
+ if self._permutation_size > self._reuse_length:
242
+ logging.warning(
243
+ '`permutation_size` is greater than `reuse_length` (%d > %d).'
244
+ 'This may introduce data leakage.', self._permutation_size,
245
+ self._reuse_length)
246
+
247
+ # Enable the memory mechanism.
248
+ # Permute the reuse and non-reuse segments separately.
249
+ non_reuse_len = self._seq_length - self._reuse_length
250
+ if not (self._reuse_length % self._permutation_size == 0 and
251
+ non_reuse_len % self._permutation_size == 0):
252
+ raise ValueError('`reuse_length` and `seq_length` should both be '
253
+ 'a multiple of `permutation_size`.')
254
+
255
+ # Creates permutation mask and target mask for the first reuse_len tokens.
256
+ # The tokens in this part are reused from the last sequence.
257
+ perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
258
+ inputs=inputs[:self._reuse_length],
259
+ input_mask=input_mask[:self._reuse_length])
260
+
261
+ # Creates permutation mask and target mask for the rest of tokens in
262
+ # current example, which are concatenation of two new segments.
263
+ perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
264
+ inputs[self._reuse_length:], input_mask[self._reuse_length:])
265
+
266
+ perm_mask_0 = tf.concat([
267
+ perm_mask_0,
268
+ tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
269
+ ],
270
+ axis=1)
271
+ perm_mask_1 = tf.concat([
272
+ tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
273
+ perm_mask_1
274
+ ],
275
+ axis=1)
276
+ perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
277
+ target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
278
+ tokens = tf.concat([tokens_0, tokens_1], axis=0)
279
+ masked_tokens = tf.concat([masked_0, masked_1], axis=0)
280
+ else:
281
+ # Disable the memory mechanism.
282
+ if self._seq_length % self._permutation_size != 0:
283
+ raise ValueError('`seq_length` should be a multiple of '
284
+ '`permutation_size`.')
285
+ # Permute the entire sequence together
286
+ perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
287
+ inputs=inputs, input_mask=input_mask)
288
+ x['permutation_mask'] = tf.reshape(perm_mask,
289
+ [self._seq_length, self._seq_length])
290
+ x['input_word_ids'] = tokens
291
+ x['masked_tokens'] = masked_tokens
292
+
293
+ target = tokens
294
+ if self._max_predictions_per_seq is not None:
295
+ indices = tf.range(self._seq_length, dtype=tf.int32)
296
+ bool_target_mask = tf.cast(target_mask, tf.bool)
297
+ indices = tf.boolean_mask(indices, bool_target_mask)
298
+
299
+ # account for extra padding due to CLS/SEP.
300
+ actual_num_predict = tf.shape(indices)[0]
301
+ pad_len = self._max_predictions_per_seq - actual_num_predict
302
+
303
+ target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
304
+ paddings = tf.zeros([pad_len, self._seq_length],
305
+ dtype=target_mapping.dtype)
306
+ target_mapping = tf.concat([target_mapping, paddings], axis=0)
307
+ x['target_mapping'] = tf.reshape(
308
+ target_mapping, [self._max_predictions_per_seq, self._seq_length])
309
+
310
+ target = tf.boolean_mask(target, bool_target_mask)
311
+ paddings = tf.zeros([pad_len], dtype=target.dtype)
312
+ target = tf.concat([target, paddings], axis=0)
313
+ x['target'] = tf.reshape(target, [self._max_predictions_per_seq])
314
+
315
+ target_mask = tf.concat([
316
+ tf.ones([actual_num_predict], dtype=tf.int32),
317
+ tf.zeros([pad_len], dtype=tf.int32)
318
+ ],
319
+ axis=0)
320
+ x['target_mask'] = tf.reshape(target_mask,
321
+ [self._max_predictions_per_seq])
322
+ else:
323
+ x['target'] = tf.reshape(target, [self._seq_length])
324
+ x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
325
+ return x
326
+
327
+ def _index_pair_to_mask(self, begin_indices: tf.Tensor,
328
+ end_indices: tf.Tensor,
329
+ inputs: tf.Tensor) -> tf.Tensor:
330
+ """Converts beginning and end indices into an actual mask."""
331
+ non_func_mask = tf.logical_and(
332
+ tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
333
+ all_indices = tf.where(
334
+ non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
335
+ tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
336
+ candidate_matrix = tf.cast(
337
+ tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
338
+ all_indices[None, :] < end_indices[:, None]), tf.float32)
339
+ cumsum_matrix = tf.reshape(
340
+ tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
341
+ masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
342
+ tf.float32)
343
+ target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
344
+ return tf.cast(target_mask, tf.bool)
345
+
346
+ def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
347
+ """Samples individual tokens as prediction targets."""
348
+ all_indices = tf.range(self._seq_length, dtype=tf.int32)
349
+ non_func_mask = tf.logical_and(
350
+ tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
351
+ non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
352
+
353
+ masked_pos = tf.random.shuffle(non_func_indices)
354
+ masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])
355
+
356
+ sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
357
+ sparse_indices = tf.cast(sparse_indices, tf.int64)
358
+
359
+ sparse_indices = tf.sparse.SparseTensor(
360
+ sparse_indices,
361
+ values=tf.ones_like(masked_pos),
362
+ dense_shape=(1, self._seq_length))
363
+
364
+ target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)
365
+
366
+ return tf.squeeze(tf.cast(target_mask, tf.bool))
367
+
368
+ def _whole_word_mask(self, inputs: tf.Tensor,
369
+ boundary: tf.Tensor) -> tf.Tensor:
370
+ """Samples whole words as prediction targets."""
371
+ pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
372
+ cand_pair_indices = tf.random.shuffle(
373
+ pair_indices)[:self._max_predictions_per_seq]
374
+ begin_indices = cand_pair_indices[:, 0]
375
+ end_indices = cand_pair_indices[:, 1]
376
+
377
+ return self._index_pair_to_mask(
378
+ begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
379
+
380
+ def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
381
+ """Samples token spans as prediction targets."""
382
+ min_num_tokens = self._params.min_num_tokens
383
+ max_num_tokens = self._params.max_num_tokens
384
+
385
+ mask_alpha = self._seq_length / self._max_predictions_per_seq
386
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
387
+
388
+ # Sample span lengths from a zipf distribution
389
+ span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
390
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
391
+
392
+ probs /= np.sum(probs)
393
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
394
+ span_lens = tf.random.categorical(
395
+ logits=logits[None],
396
+ num_samples=self._max_predictions_per_seq,
397
+ dtype=tf.int32,
398
+ )[0] + min_num_tokens
399
+
400
+ # Sample the ratio [0.0, 1.0) of left context lengths
401
+ span_lens_float = tf.cast(span_lens, tf.float32)
402
+ left_ratio = tf.random.uniform(
403
+ shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
404
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
405
+ left_ctx_len = round_to_int(left_ctx_len)
406
+
407
+ # Compute the offset from left start to the right end
408
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
409
+
410
+ # Get the actual begin and end indices
411
+ begin_indices = (
412
+ tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
413
+ end_indices = begin_indices + span_lens
414
+
415
+ # Remove out of range indices
416
+ valid_idx_mask = end_indices < self._seq_length
417
+ begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
418
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
419
+
420
+ # Shuffle valid indices
421
+ num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
422
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
423
+ begin_indices = tf.gather(begin_indices, order)
424
+ end_indices = tf.gather(end_indices, order)
425
+
426
+ return self._index_pair_to_mask(
427
+ begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
428
+
429
+ def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
430
+ """Sample whole word spans as prediction targets."""
431
+ min_num_words = self._params.min_num_words
432
+ max_num_words = self._params.max_num_words
433
+
434
+ # Note: 1.2 is the token-to-word ratio
435
+ mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
436
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
437
+
438
+ # Sample span lengths from a zipf distribution
439
+ span_len_seq = np.arange(min_num_words, max_num_words + 1)
440
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
441
+ probs /= np.sum(probs)
442
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
443
+
444
+ # Sample `num_predict` words here: note that this is over sampling
445
+ span_lens = tf.random.categorical(
446
+ logits=logits[None],
447
+ num_samples=self._max_predictions_per_seq,
448
+ dtype=tf.int32,
449
+ )[0] + min_num_words
450
+
451
+ # Sample the ratio [0.0, 1.0) of left context lengths
452
+ span_lens_float = tf.cast(span_lens, tf.float32)
453
+ left_ratio = tf.random.uniform(
454
+ shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
455
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
456
+
457
+ left_ctx_len = round_to_int(left_ctx_len)
458
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
459
+
460
+ begin_indices = (
461
+ tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
462
+ end_indices = begin_indices + span_lens
463
+
464
+ # Remove out of range indices
465
+ max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
466
+ valid_idx_mask = end_indices < max_boundary_index
467
+ begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
468
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
469
+
470
+ begin_indices = tf.gather(boundary, begin_indices)
471
+ end_indices = tf.gather(boundary, end_indices)
472
+
473
+ # Shuffle valid indices
474
+ num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
475
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
476
+ begin_indices = tf.gather(begin_indices, order)
477
+ end_indices = tf.gather(end_indices, order)
478
+
479
+ return self._index_pair_to_mask(
480
+ begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
481
+
482
+ def _online_sample_mask(self, inputs: tf.Tensor,
483
+ boundary: tf.Tensor) -> tf.Tensor:
484
+ """Samples target positions for predictions.
485
+
486
+ Descriptions of each strategy:
487
+ - 'single_token': Samples individual tokens as prediction targets.
488
+ - 'token_span': Samples spans of tokens as prediction targets.
489
+ - 'whole_word': Samples individual words as prediction targets.
490
+ - 'word_span': Samples spans of words as prediction targets.
491
+
492
+ Args:
493
+ inputs: The input tokens.
494
+ boundary: The `int` Tensor of indices indicating whole word boundaries.
495
+ This is used in 'whole_word' and 'word_span'
496
+
497
+ Returns:
498
+ The sampled `bool` input mask.
499
+
500
+ Raises:
501
+ `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
502
+ not provided for 'whole_word' and 'word_span' sample strategies.
503
+ """
504
+ if self._max_predictions_per_seq is None:
505
+ raise ValueError('`max_predictions_per_seq` must be set.')
506
+
507
+ if boundary is None and 'word' in self._sample_strategy:
508
+ raise ValueError('`boundary` must be provided for {} strategy'.format(
509
+ self._sample_strategy))
510
+
511
+ if self._sample_strategy == 'single_token':
512
+ return self._single_token_mask(inputs)
513
+ elif self._sample_strategy == 'token_span':
514
+ return self._token_span_mask(inputs)
515
+ elif self._sample_strategy == 'whole_word':
516
+ return self._whole_word_mask(inputs, boundary)
517
+ elif self._sample_strategy == 'word_span':
518
+ return self._word_span_mask(inputs, boundary)
519
+ else:
520
+ raise NotImplementedError('Invalid sample strategy.')
521
+
522
+ def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
523
+ """Samples a permutation of the factorization order.
524
+
525
+ Args:
526
+ inputs: the input tokens.
527
+ input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
528
+ then this means select for partial prediction.
529
+
530
+ Returns:
531
+ perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
532
+ of 0s and 1s. If perm_mask[i][j] == 0, then this means that the i-th
533
+ token (in original order) cannot attend to the jth attention token.
534
+ target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and 1s.
535
+ If target_mask[i] == 1, then the i-th token needs to be predicted and
536
+ the mask will be used as input. This token will be included in the loss.
537
+ If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
538
+ input. This token will not be included in the loss.
539
+ tokens: int32 Tensor of shape [seq_length].
540
+ masked_tokens: int32 Tensor of shape [seq_length].
541
+ """
542
+ factorization_length = tf.shape(inputs)[0]
543
+ # Generate permutation indices
544
+ index = tf.range(factorization_length, dtype=tf.int32)
545
+ index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
546
+ index = tf.random.shuffle(index)
547
+ index = tf.reshape(tf.transpose(index), [-1])
548
+
549
+ input_mask = tf.cast(input_mask, tf.bool)
550
+
551
+ # non-functional tokens
552
+ non_func_tokens = tf.logical_not(
553
+ tf.logical_or(
554
+ tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
555
+ masked_tokens = tf.logical_and(input_mask, non_func_tokens)
556
+ non_masked_or_func_tokens = tf.logical_not(masked_tokens)
557
+
558
+ smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)
559
+
560
+ # Similar to BERT, randomly leak some masked tokens
561
+ if self._leak_ratio > 0:
562
+ leak_tokens = tf.logical_and(
563
+ masked_tokens,
564
+ tf.random.uniform([factorization_length], maxval=1.0) <
565
+ self._leak_ratio)
566
+ can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
567
+ else:
568
+ can_attend_self = non_masked_or_func_tokens
569
+ to_index = tf.where(can_attend_self, smallest_index, index)
570
+ from_index = tf.where(can_attend_self, to_index + 1, to_index)
571
+
572
+ # For masked tokens, can attend if i > j
573
+ # For context tokens, always can attend each other
574
+ can_attend = from_index[:, None] > to_index[None, :]
575
+
576
+ perm_mask = tf.cast(can_attend, tf.int32)
577
+
578
+ # Only masked tokens are included in the loss
579
+ target_mask = tf.cast(masked_tokens, tf.int32)
580
+
581
+ return perm_mask, target_mask, inputs, masked_tokens
582
+
583
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
584
+ """Returns a tf.dataset.Dataset."""
585
+ if input_context:
586
+ self._num_replicas_in_sync = input_context.num_replicas_in_sync
587
+ reader = input_reader.InputReader(
588
+ params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
589
+ return reader.read(input_context)
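For reference, a minimal usage sketch of the BERT pretraining loader defined above (not part of the uploaded files; the TFRecord path is a hypothetical placeholder and the records are assumed to already carry the masked-LM features listed in `_name_to_features`):

# Hedged sketch: build the BERT pretraining input pipeline from this module.
from official.nlp.data import pretrain_dataloader

config = pretrain_dataloader.BertPretrainDataConfig(
    input_path='/tmp/train.tf_record',  # hypothetical path, adjust as needed
    seq_length=128,
    max_predictions_per_seq=20,
    global_batch_size=32,
    is_training=True)
dataset = pretrain_dataloader.BertPretrainDataLoader(config).load()
features = next(iter(dataset))
# `features` maps 'input_word_ids', 'input_mask', 'input_type_ids',
# 'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights' (and
# 'next_sentence_labels' by default) to batched int32/float32 tensors.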
pretrain_dataloader_test.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.pretrain_dataloader."""
16
+ import itertools
17
+ import os
18
+
19
+ from absl.testing import parameterized
20
+ import numpy as np
21
+ import tensorflow as tf, tf_keras
22
+
23
+ from official.nlp.data import pretrain_dataloader
24
+
25
+
26
+ def create_int_feature(values):
27
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
28
+ return f
29
+
30
+
31
+ def _create_fake_bert_dataset(
32
+ output_path,
33
+ seq_length,
34
+ max_predictions_per_seq,
35
+ use_position_id,
36
+ use_next_sentence_label,
37
+ use_v2_feature_names=False):
38
+ """Creates a fake dataset."""
39
+ writer = tf.io.TFRecordWriter(output_path)
40
+
41
+ def create_float_feature(values):
42
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
43
+ return f
44
+
45
+ for _ in range(100):
46
+ features = {}
47
+ input_ids = np.random.randint(100, size=(seq_length))
48
+ features["input_mask"] = create_int_feature(np.ones_like(input_ids))
49
+ if use_v2_feature_names:
50
+ features["input_word_ids"] = create_int_feature(input_ids)
51
+ features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
52
+ else:
53
+ features["input_ids"] = create_int_feature(input_ids)
54
+ features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
55
+
56
+ features["masked_lm_positions"] = create_int_feature(
57
+ np.random.randint(100, size=(max_predictions_per_seq)))
58
+ features["masked_lm_ids"] = create_int_feature(
59
+ np.random.randint(100, size=(max_predictions_per_seq)))
60
+ features["masked_lm_weights"] = create_float_feature(
61
+ [1.0] * max_predictions_per_seq)
62
+
63
+ if use_next_sentence_label:
64
+ features["next_sentence_labels"] = create_int_feature([1])
65
+
66
+ if use_position_id:
67
+ features["position_ids"] = create_int_feature(range(0, seq_length))
68
+
69
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
70
+ writer.write(tf_example.SerializeToString())
71
+ writer.close()
72
+
73
+
74
+ def _create_fake_xlnet_dataset(
75
+ output_path, seq_length, max_predictions_per_seq):
76
+ """Creates a fake dataset."""
77
+ writer = tf.io.TFRecordWriter(output_path)
78
+ for _ in range(100):
79
+ features = {}
80
+ input_ids = np.random.randint(100, size=(seq_length))
81
+ num_boundary_indices = np.random.randint(1, seq_length)
82
+
83
+ if max_predictions_per_seq is not None:
84
+ input_mask = np.zeros_like(input_ids)
85
+ input_mask[:max_predictions_per_seq] = 1
86
+ np.random.shuffle(input_mask)
87
+ else:
88
+ input_mask = np.ones_like(input_ids)
89
+
90
+ features["input_mask"] = create_int_feature(input_mask)
91
+ features["input_word_ids"] = create_int_feature(input_ids)
92
+ features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
93
+ features["boundary_indices"] = create_int_feature(
94
+ sorted(np.random.randint(seq_length, size=(num_boundary_indices))))
95
+ features["target"] = create_int_feature(input_ids + 1)
96
+ features["label"] = create_int_feature([1])
97
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
98
+ writer.write(tf_example.SerializeToString())
99
+ writer.close()
100
+
101
+
102
+ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
103
+
104
+ @parameterized.parameters(itertools.product(
105
+ (False, True),
106
+ (False, True),
107
+ ))
108
+ def test_load_data(self, use_next_sentence_label, use_position_id):
109
+ train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
110
+ seq_length = 128
111
+ max_predictions_per_seq = 20
112
+ _create_fake_bert_dataset(
113
+ train_data_path,
114
+ seq_length,
115
+ max_predictions_per_seq,
116
+ use_next_sentence_label=use_next_sentence_label,
117
+ use_position_id=use_position_id)
118
+ data_config = pretrain_dataloader.BertPretrainDataConfig(
119
+ input_path=train_data_path,
120
+ max_predictions_per_seq=max_predictions_per_seq,
121
+ seq_length=seq_length,
122
+ global_batch_size=10,
123
+ is_training=True,
124
+ use_next_sentence_label=use_next_sentence_label,
125
+ use_position_id=use_position_id)
126
+
127
+ dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
128
+ features = next(iter(dataset))
129
+ self.assertLen(features,
130
+ 6 + int(use_next_sentence_label) + int(use_position_id))
131
+ self.assertIn("input_word_ids", features)
132
+ self.assertIn("input_mask", features)
133
+ self.assertIn("input_type_ids", features)
134
+ self.assertIn("masked_lm_positions", features)
135
+ self.assertIn("masked_lm_ids", features)
136
+ self.assertIn("masked_lm_weights", features)
137
+
138
+ self.assertEqual("next_sentence_labels" in features,
139
+ use_next_sentence_label)
140
+ self.assertEqual("position_ids" in features, use_position_id)
141
+
142
+ def test_v2_feature_names(self):
143
+ train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
144
+ seq_length = 128
145
+ max_predictions_per_seq = 20
146
+ _create_fake_bert_dataset(
147
+ train_data_path,
148
+ seq_length,
149
+ max_predictions_per_seq,
150
+ use_next_sentence_label=True,
151
+ use_position_id=False,
152
+ use_v2_feature_names=True)
153
+ data_config = pretrain_dataloader.BertPretrainDataConfig(
154
+ input_path=train_data_path,
155
+ max_predictions_per_seq=max_predictions_per_seq,
156
+ seq_length=seq_length,
157
+ global_batch_size=10,
158
+ is_training=True,
159
+ use_next_sentence_label=True,
160
+ use_position_id=False,
161
+ use_v2_feature_names=True)
162
+
163
+ dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
164
+ features = next(iter(dataset))
165
+ self.assertIn("input_word_ids", features)
166
+ self.assertIn("input_mask", features)
167
+ self.assertIn("input_type_ids", features)
168
+ self.assertIn("masked_lm_positions", features)
169
+ self.assertIn("masked_lm_ids", features)
170
+ self.assertIn("masked_lm_weights", features)
171
+
172
+
173
+ class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
174
+
175
+ @parameterized.parameters(itertools.product(
176
+ ("single_token", "whole_word", "token_span"),
177
+ (0, 64),
178
+ (20, None),
179
+ ))
180
+ def test_load_data(
181
+ self, sample_strategy, reuse_length, max_predictions_per_seq):
182
+ train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
183
+ seq_length = 128
184
+ batch_size = 5
185
+
186
+ _create_fake_xlnet_dataset(
187
+ train_data_path, seq_length, max_predictions_per_seq)
188
+
189
+ data_config = pretrain_dataloader.XLNetPretrainDataConfig(
190
+ input_path=train_data_path,
191
+ max_predictions_per_seq=max_predictions_per_seq,
192
+ seq_length=seq_length,
193
+ global_batch_size=batch_size,
194
+ is_training=True,
195
+ reuse_length=reuse_length,
196
+ sample_strategy=sample_strategy,
197
+ min_num_tokens=1,
198
+ max_num_tokens=2,
199
+ permutation_size=seq_length // 2,
200
+ leak_ratio=0.1)
201
+
202
+ if max_predictions_per_seq is None:
203
+ with self.assertRaises(ValueError):
204
+ dataset = pretrain_dataloader.XLNetPretrainDataLoader(
205
+ data_config).load()
206
+ features = next(iter(dataset))
207
+ else:
208
+ dataset = pretrain_dataloader.XLNetPretrainDataLoader(data_config).load()
209
+ features = next(iter(dataset))
210
+
211
+ self.assertIn("input_word_ids", features)
212
+ self.assertIn("input_type_ids", features)
213
+ self.assertIn("permutation_mask", features)
214
+ self.assertIn("masked_tokens", features)
215
+ self.assertIn("target", features)
216
+ self.assertIn("target_mask", features)
217
+
218
+ self.assertAllClose(features["input_word_ids"].shape,
219
+ (batch_size, seq_length))
220
+ self.assertAllClose(features["input_type_ids"].shape,
221
+ (batch_size, seq_length))
222
+ self.assertAllClose(features["permutation_mask"].shape,
223
+ (batch_size, seq_length, seq_length))
224
+ self.assertAllClose(features["masked_tokens"].shape,
225
+ (batch_size, seq_length,))
226
+ if max_predictions_per_seq is not None:
227
+ self.assertIn("target_mapping", features)
228
+ self.assertAllClose(features["target_mapping"].shape,
229
+ (batch_size, max_predictions_per_seq, seq_length))
230
+ self.assertAllClose(features["target_mask"].shape,
231
+ (batch_size, max_predictions_per_seq))
232
+ self.assertAllClose(features["target"].shape,
233
+ (batch_size, max_predictions_per_seq))
234
+ else:
235
+ self.assertAllClose(features["target_mask"].shape,
236
+ (batch_size, seq_length))
237
+ self.assertAllClose(features["target"].shape,
238
+ (batch_size, seq_length))
239
+
240
+
241
+ if __name__ == "__main__":
242
+ tf.test.main()
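The sampling strategies documented in `XLNetPretrainDataConfig` above can be exercised the same way; a hedged sketch with a hypothetical input path and illustrative hyperparameters (the 'word_span' strategy assumes the records carry `boundary_indices`):

from official.nlp.data import pretrain_dataloader

xlnet_config = pretrain_dataloader.XLNetPretrainDataConfig(
    input_path='/tmp/xlnet_train.tf_record',  # hypothetical path
    seq_length=512,
    max_predictions_per_seq=85,
    reuse_length=256,
    permutation_size=256,          # keep <= reuse_length to avoid leakage
    sample_strategy='word_span',   # needs `boundary_indices` in the records
    leak_ratio=0.1,
    global_batch_size=32,
    is_training=True)
xlnet_dataset = pretrain_dataloader.XLNetPretrainDataLoader(xlnet_config).load()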
pretrain_dynamic_dataloader.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dataset loader for the pre-training with dynamic sequence length."""
16
+ from typing import Optional, Tuple
17
+
18
+ import dataclasses
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.core import config_definitions as cfg
22
+ from official.core import input_reader
23
+ from official.nlp.data import data_loader_factory
24
+ from official.nlp.data import pretrain_dataloader
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class BertPretrainDataConfig(cfg.DataConfig):
29
+ """Data config for BERT pretraining task (tasks/masked_lm)."""
30
+ input_path: str = ''
31
+ global_batch_size: int = 512
32
+ is_training: bool = True
33
+ seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512,)
34
+ # TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
35
+ # tf.data service is disabled. Deprecate this flag once we always enable round
36
+ # robin tf.data service.
37
+ seq_bucket_window_scale: int = 8
38
+ use_next_sentence_label: bool = True
39
+ use_position_id: bool = False
40
+ deterministic: bool = False
41
+ enable_tf_data_service: bool = False
42
+ enable_round_robin_tf_data_service: bool = False
43
+ tf_data_service_job_name: str = 'bert_pretrain'
44
+ use_v2_feature_names: bool = False
45
+
46
+
47
+ @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
48
+ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
49
+ """Dataset loader for bert-style pretraining with dynamic sequenece length.
50
+
51
+ Bucketizes the input id features by the seq_bucket_lengths and features are
52
+ padded to the bucket boundaries. The mask features are usually shorter than
54
+ the input id features and can also be dynamic. We require that the mask
55
+ feature lengths within a bucket be the same. For example, with [128, 256]
56
+ buckets, the mask features for bucket 128 should always have some fixed
57
+ length X, and the mask features for bucket 256 some fixed length Y.
57
+
58
+ The dataloader does not filter out empty masks. Make sure to handle this
59
+ in the model.
60
+ """
61
+
62
+ def __init__(self, params):
63
+ self._params = params
64
+ if len(params.seq_bucket_lengths) < 1:
65
+ raise ValueError('The seq_bucket_lengths cannot be empty.')
66
+ self._seq_bucket_lengths = params.seq_bucket_lengths
67
+ self._seq_bucket_window_scale = params.seq_bucket_window_scale
68
+ self._global_batch_size = params.global_batch_size
69
+ self._use_next_sentence_label = params.use_next_sentence_label
70
+ self._use_position_id = params.use_position_id
71
+ self._drop_remainder = params.drop_remainder
72
+ self._enable_tf_data_service = params.enable_tf_data_service
73
+ self._enable_round_robin_tf_data_service = (
74
+ params.enable_round_robin_tf_data_service)
75
+ self._mask_keys = [
76
+ 'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
77
+ ]
78
+
79
+ def _decode(self, record: tf.Tensor):
80
+ """Decodes a serialized tf.Example."""
81
+ name_to_features = {
82
+ 'input_mask': tf.io.VarLenFeature(tf.int64),
83
+ 'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
84
+ 'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
85
+ 'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
86
+ }
87
+ if self._params.use_v2_feature_names:
88
+ input_ids_key = 'input_word_ids'
89
+ segment_key = 'input_type_ids'
90
+ name_to_features.update({
91
+ input_ids_key: tf.io.VarLenFeature(tf.int64),
92
+ segment_key: tf.io.VarLenFeature(tf.int64),
93
+ })
94
+ else:
95
+ input_ids_key = 'input_ids'
96
+ segment_key = 'segment_ids'
97
+ name_to_features.update({
98
+ input_ids_key: tf.io.VarLenFeature(tf.int64),
99
+ segment_key: tf.io.VarLenFeature(tf.int64),
100
+ })
101
+ if self._use_next_sentence_label:
102
+ name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
103
+ tf.int64)
104
+ dynamic_keys = [input_ids_key, 'input_mask', segment_key]
105
+ if self._use_position_id:
106
+ name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
107
+ dynamic_keys.append('position_ids')
108
+
109
+ example = tf.io.parse_single_example(record, name_to_features)
110
+ for key in dynamic_keys + self._mask_keys:
111
+ example[key] = tf.sparse.to_dense(example[key])
112
+
113
+ # Truncate trailing padding, i.e. everything after the last non-pad token
114
+ # along the sequence length dimension.
115
+ # Padding that appears before the last non-pad token is kept.
116
+ mask = tf.math.greater(
117
+ tf.math.cumsum(example[input_ids_key], reverse=True), 0)
118
+ for key in dynamic_keys:
119
+ example[key] = tf.boolean_mask(example[key], mask)
120
+
121
+ # masked_lm_ids arrives 0-padded. Re-pad it with -1 so that padding coming
122
+ # from the data can later be distinguished from the 0-padding introduced by
123
+ # bucketizing.
124
+ mask = tf.math.not_equal(example['masked_lm_ids'], 0)
125
+ example['masked_lm_ids'] = tf.where(
126
+ mask, example['masked_lm_ids'],
127
+ -tf.ones_like(example['masked_lm_ids']))
128
+
129
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
130
+ # So cast all int64 to int32.
131
+ # tf.data service uses dataset graph fingerprint to distinguish input
132
+ # pipeline jobs, thus we sort the keys here to make sure they are generated
133
+ # in a deterministic order each time the dataset function is traced.
134
+ for name in sorted(list(example.keys())):
135
+ t = example[name]
136
+ if t.dtype == tf.int64:
137
+ t = tf.cast(t, tf.int32)
138
+ example[name] = t
139
+
140
+ return example
141
+
142
+ def _bucketize_and_batch(
143
+ self,
144
+ dataset,
145
+ input_context: Optional[tf.distribute.InputContext] = None):
146
+ """Bucketize by sequence length and batch the datasets."""
147
+ per_replica_batch_size = input_context.get_per_replica_batch_size(
148
+ self._global_batch_size) if input_context else self._global_batch_size
149
+
150
+ def element_length_func(example, seq_len_dim):
151
+ return tf.shape(example['input_word_ids'])[seq_len_dim]
152
+
153
+ bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
154
+ bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)
155
+
156
+ # Bucketize and batch the dataset with per replica batch size first.
157
+ dataset = dataset.apply(
158
+ tf.data.experimental.bucket_by_sequence_length(
159
+ lambda example: tf.cast(element_length_func(example, 0), tf.int32),
160
+ bucket_boundaries,
161
+ bucket_batch_sizes,
162
+ pad_to_bucket_boundary=True,
163
+ drop_remainder=self._drop_remainder))
164
+ if input_context:
165
+ window_size = input_context.num_replicas_in_sync
166
+ if self._enable_tf_data_service and (
167
+ not self._enable_round_robin_tf_data_service):
168
+ # If tf.data service is enabled but round-robin behavior is not enabled,
169
+ # different TPU workers may fetch data from one tf.data service worker
170
+ # in different speed. We set the window size to be
171
+ # `seq_bucket_window_scale` larger to leave buffer if some workers are
172
+ # fetching data faster than others, so all the data within the same
173
+ # global batch can still have more chances to be in the same bucket.
174
+ window_size *= self._seq_bucket_window_scale
175
+
176
+ # Group `num_replicas_in_sync` batches from same bucket together, so all
177
+ # replicas can get the same sequence length for one global step.
178
+ dataset = dataset.apply(
179
+ tf.data.experimental.group_by_window(
180
+ key_func=lambda example: tf.cast( # pylint: disable=g-long-lambda
181
+ element_length_func(example, 1), tf.int64),
182
+ reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
183
+ window_size=window_size))
184
+ dataset = dataset.flat_map(lambda x: x)
185
+
186
+ def _remove_pads_from_bucketize(features):
187
+ # All mask features must have the same effective length.
188
+ # The real masked ids padding token is -1 and 0 comes from
189
+ # bucket_by_sequence_length.
190
+ mask = tf.math.not_equal(features['masked_lm_ids'], 0)
191
+
192
+ mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
193
+ normalized = tf.cast(
194
+ mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
195
+ assert_op = tf.debugging.assert_equal(
196
+ tf.math.reduce_sum(normalized), per_replica_batch_size,
197
+ 'Number of non padded mask tokens is not the same for each example '
198
+ 'in the same sequence length.')
199
+ with tf.control_dependencies([assert_op]):
200
+ for key in self._mask_keys:
201
+ features[key] = tf.reshape(
202
+ tf.boolean_mask(
203
+ features[key], mask), [per_replica_batch_size, -1])
204
+ # Revert masked_lm_ids to be 0-padded.
205
+ mask = tf.math.not_equal(features['masked_lm_ids'], -1)
206
+ features['masked_lm_ids'] = tf.where(
207
+ mask, features['masked_lm_ids'],
208
+ tf.zeros(
209
+ tf.shape(features['masked_lm_ids']),
210
+ dtype=features['masked_lm_ids'].dtype))
211
+ return features
212
+
213
+ dataset = dataset.map(_remove_pads_from_bucketize)
214
+ return dataset
215
+
216
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
217
+ """Returns a tf.dataset.Dataset."""
218
+ reader = input_reader.InputReader(
219
+ params=self._params,
220
+ decoder_fn=self._decode,
221
+ parser_fn=self._parse,
222
+ transform_and_batch_fn=self._bucketize_and_batch)
223
+ return reader.read(input_context)
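The dynamic loader above leans on `tf.data.experimental.bucket_by_sequence_length` with `pad_to_bucket_boundary=True`. A self-contained toy illustration of that primitive (all values here are illustrative, not taken from the module):

import tensorflow as tf

# Toy ragged-length "token id" dataset with lengths 5, 9, 20 and 30.
ds = tf.data.Dataset.from_generator(
    lambda: ([1] * n for n in (5, 9, 20, 30)),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))

# Boundaries [17, 33] pad the two buckets to lengths 16 and 32, mirroring how
# the loader computes bucket_boundaries = [length + 1 for length in
# seq_bucket_lengths].
ds = ds.apply(
    tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda x: tf.shape(x)[0],
        bucket_boundaries=[17, 33],
        bucket_batch_sizes=[2, 2, 2],
        pad_to_bucket_boundary=True,
        drop_remainder=True))

for batch in ds:
  print(batch.shape)  # (2, 16) then (2, 32)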
pretrain_dynamic_dataloader_test.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for nlp.data.pretrain_dynamic_dataloader."""
16
+ import os
17
+
18
+ from absl import logging
19
+ from absl.testing import parameterized
20
+ import numpy as np
21
+ import orbit
22
+ import tensorflow as tf, tf_keras
23
+
24
+ from tensorflow.python.distribute import combinations
25
+ from tensorflow.python.distribute import strategy_combinations
26
+ from official.nlp.configs import bert
27
+ from official.nlp.configs import encoders
28
+ from official.nlp.data import pretrain_dataloader
29
+ from official.nlp.data import pretrain_dynamic_dataloader
30
+ from official.nlp.tasks import masked_lm
31
+
32
+
33
+ def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
34
+ max_seq_length, num_examples):
35
+ """Creates a fake dataset."""
36
+ writer = tf.io.TFRecordWriter(output_path)
37
+
38
+ def create_int_feature(values):
39
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
40
+ return f
41
+
42
+ def create_float_feature(values):
43
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
44
+ return f
45
+
46
+ rng = np.random.default_rng(37)
47
+ for _ in range(num_examples):
48
+ features = {}
49
+ padding = np.zeros(shape=(max_seq_length - seq_length), dtype=np.int32)
50
+ input_ids = rng.integers(low=1, high=100, size=(seq_length))
51
+ features['input_ids'] = create_int_feature(
52
+ np.concatenate((input_ids, padding)))
53
+ features['input_mask'] = create_int_feature(
54
+ np.concatenate((np.ones_like(input_ids), padding)))
55
+ features['segment_ids'] = create_int_feature(
56
+ np.concatenate((np.ones_like(input_ids), padding)))
57
+ features['position_ids'] = create_int_feature(
58
+ np.concatenate((np.ones_like(input_ids), padding)))
59
+ features['masked_lm_positions'] = create_int_feature(
60
+ rng.integers(60, size=(num_masked_tokens), dtype=np.int64))
61
+ features['masked_lm_ids'] = create_int_feature(
62
+ rng.integers(100, size=(num_masked_tokens), dtype=np.int64))
63
+ features['masked_lm_weights'] = create_float_feature(
64
+ np.ones((num_masked_tokens,), dtype=np.float32))
65
+ features['next_sentence_labels'] = create_int_feature(np.array([0]))
66
+
67
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
68
+ writer.write(tf_example.SerializeToString())
69
+ writer.close()
70
+
71
+
72
+ class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
73
+
74
+ @combinations.generate(
75
+ combinations.combine(
76
+ distribution_strategy=[
77
+ strategy_combinations.cloud_tpu_strategy,
78
+ ],
79
+ mode='eager'))
80
+ def test_distribution_strategy(self, distribution_strategy):
81
+ max_seq_length = 128
82
+ batch_size = 8
83
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
84
+ _create_fake_dataset(
85
+ input_path,
86
+ seq_length=60,
87
+ num_masked_tokens=20,
88
+ max_seq_length=max_seq_length,
89
+ num_examples=batch_size)
90
+ data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
91
+ is_training=False,
92
+ input_path=input_path,
93
+ seq_bucket_lengths=[64, 128],
94
+ global_batch_size=batch_size)
95
+ dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
96
+ data_config)
97
+ distributed_ds = orbit.utils.make_distributed_dataset(
98
+ distribution_strategy, dataloader.load)
99
+ train_iter = iter(distributed_ds)
100
+ with distribution_strategy.scope():
101
+ config = masked_lm.MaskedLMConfig(
102
+ init_checkpoint=self.get_temp_dir(),
103
+ model=bert.PretrainerConfig(
104
+ encoders.EncoderConfig(
105
+ bert=encoders.BertEncoderConfig(
106
+ vocab_size=30522, num_layers=1)),
107
+ cls_heads=[
108
+ bert.ClsHeadConfig(
109
+ inner_dim=10, num_classes=2, name='next_sentence')
110
+ ]),
111
+ train_data=data_config)
112
+ task = masked_lm.MaskedLMTask(config)
113
+ model = task.build_model()
114
+ metrics = task.build_metrics()
115
+
116
+ @tf.function
117
+ def step_fn(features):
118
+ return task.validation_step(features, model, metrics=metrics)
119
+
120
+ distributed_outputs = distribution_strategy.run(
121
+ step_fn, args=(next(train_iter),))
122
+ local_results = tf.nest.map_structure(
123
+ distribution_strategy.experimental_local_results, distributed_outputs)
124
+ logging.info('Dynamic padding: local_results= %s', str(local_results))
125
+ dynamic_metrics = {}
126
+ for metric in metrics:
127
+ dynamic_metrics[metric.name] = metric.result()
128
+
129
+ data_config = pretrain_dataloader.BertPretrainDataConfig(
130
+ is_training=False,
131
+ input_path=input_path,
132
+ seq_length=max_seq_length,
133
+ max_predictions_per_seq=20,
134
+ global_batch_size=batch_size)
135
+ dataloader = pretrain_dataloader.BertPretrainDataLoader(data_config)
136
+ distributed_ds = orbit.utils.make_distributed_dataset(
137
+ distribution_strategy, dataloader.load)
138
+ train_iter = iter(distributed_ds)
139
+ with distribution_strategy.scope():
140
+ metrics = task.build_metrics()
141
+
142
+ @tf.function
143
+ def step_fn_b(features):
144
+ return task.validation_step(features, model, metrics=metrics)
145
+
146
+ distributed_outputs = distribution_strategy.run(
147
+ step_fn_b, args=(next(train_iter),))
148
+ local_results = tf.nest.map_structure(
149
+ distribution_strategy.experimental_local_results, distributed_outputs)
150
+ logging.info('Static padding: local_results= %s', str(local_results))
151
+ static_metrics = {}
152
+ for metric in metrics:
153
+ static_metrics[metric.name] = metric.result()
154
+ for key in static_metrics:
155
+ # We need to investigate the differences in losses.
156
+ if key != 'next_sentence_loss':
157
+ self.assertEqual(dynamic_metrics[key], static_metrics[key])
158
+
159
+ def test_load_dataset(self):
160
+ tf.random.set_seed(0)
161
+ max_seq_length = 128
162
+ batch_size = 2
163
+ input_path_1 = os.path.join(self.get_temp_dir(), 'train_1.tf_record')
164
+ _create_fake_dataset(
165
+ input_path_1,
166
+ seq_length=60,
167
+ num_masked_tokens=20,
168
+ max_seq_length=max_seq_length,
169
+ num_examples=batch_size)
170
+ input_path_2 = os.path.join(self.get_temp_dir(), 'train_2.tf_record')
171
+ _create_fake_dataset(
172
+ input_path_2,
173
+ seq_length=100,
174
+ num_masked_tokens=70,
175
+ max_seq_length=max_seq_length,
176
+ num_examples=batch_size)
177
+ input_paths = ','.join([input_path_1, input_path_2])
178
+ data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
179
+ is_training=False,
180
+ input_path=input_paths,
181
+ seq_bucket_lengths=[64, 128],
182
+ use_position_id=True,
183
+ global_batch_size=batch_size,
184
+ deterministic=True)
185
+ dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
186
+ data_config).load()
187
+ dataset_it = iter(dataset)
188
+ features = next(dataset_it)
189
+ self.assertCountEqual([
190
+ 'input_word_ids',
191
+ 'input_mask',
192
+ 'input_type_ids',
193
+ 'next_sentence_labels',
194
+ 'masked_lm_positions',
195
+ 'masked_lm_ids',
196
+ 'masked_lm_weights',
197
+ 'position_ids',
198
+ ], features.keys())
199
+ # Sequence length dimension should be bucketized and pad to 64.
200
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, 64))
201
+ self.assertEqual(features['input_mask'].shape, (batch_size, 64))
202
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, 64))
203
+ self.assertEqual(features['position_ids'].shape, (batch_size, 64))
204
+ self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 20))
205
+ features = next(dataset_it)
206
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, 128))
207
+ self.assertEqual(features['input_mask'].shape, (batch_size, 128))
208
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, 128))
209
+ self.assertEqual(features['position_ids'].shape, (batch_size, 128))
210
+ self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 70))
211
+
212
+ def test_load_dataset_not_same_masks(self):
213
+ max_seq_length = 128
214
+ batch_size = 2
215
+ input_path_1 = os.path.join(self.get_temp_dir(), 'train_3.tf_record')
216
+ _create_fake_dataset(
217
+ input_path_1,
218
+ seq_length=60,
219
+ num_masked_tokens=20,
220
+ max_seq_length=max_seq_length,
221
+ num_examples=batch_size)
222
+ input_path_2 = os.path.join(self.get_temp_dir(), 'train_4.tf_record')
223
+ _create_fake_dataset(
224
+ input_path_2,
225
+ seq_length=60,
226
+ num_masked_tokens=15,
227
+ max_seq_length=max_seq_length,
228
+ num_examples=batch_size)
229
+ input_paths = ','.join([input_path_1, input_path_2])
230
+ data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
231
+ is_training=False,
232
+ input_path=input_paths,
233
+ seq_bucket_lengths=[64, 128],
234
+ use_position_id=True,
235
+ global_batch_size=batch_size * 2)
236
+ dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
237
+ data_config).load()
238
+ dataset_it = iter(dataset)
239
+ with self.assertRaisesRegex(
240
+ tf.errors.InvalidArgumentError, '.*Number of non padded mask tokens.*'):
241
+ next(dataset_it)
242
+
243
+
244
+ if __name__ == '__main__':
245
+ tf.test.main()
pretrain_text_dataloader.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Loads text dataset for the BERT pretraining task."""
16
+ import dataclasses
17
+ from typing import List, Mapping, Optional, Text
18
+
19
+ import tensorflow as tf, tf_keras
20
+ import tensorflow_text as tf_text
21
+
22
+ from official.common import dataset_fn
23
+ from official.core import config_definitions as cfg
24
+ from official.core import input_reader
25
+ from official.nlp.data import data_loader
26
+ from official.nlp.data import data_loader_factory
27
+ from official.nlp.modeling.ops import segment_extractor
28
+
29
+
30
+ @dataclasses.dataclass
31
+ class BertPretrainTextDataConfig(cfg.DataConfig):
32
+ """Data config for BERT pretraining task (tasks/masked_lm) from text."""
33
+ input_path: str = ""
34
+ doc_batch_size: int = 8
35
+ global_batch_size: int = 512
36
+ is_training: bool = True
37
+ seq_length: int = 512
38
+ max_predictions_per_seq: int = 76
39
+ use_next_sentence_label: bool = True
40
+ # The name of the text feature fields. The text features will be
41
+ # concatenated in order.
42
+ # Note: More than 1 field name is not compatible with NSP.
43
+ text_field_names: Optional[List[str]] = dataclasses.field(
44
+ default_factory=lambda: ["text"])
45
+ vocab_file_path: str = ""
46
+ masking_rate: float = 0.15
47
+ use_whole_word_masking: bool = False
48
+ file_type: str = "tfrecord"
49
+
50
+
51
+ _CLS_TOKEN = b"[CLS]"
52
+ _SEP_TOKEN = b"[SEP]"
53
+ _MASK_TOKEN = b"[MASK]"
54
+ _NUM_OOV_BUCKETS = 1
55
+ # Accounts for [CLS] and 2 x [SEP] tokens
56
+ _NUM_SPECIAL_TOKENS = 3
57
+
58
+
59
+ @data_loader_factory.register_data_loader_cls(BertPretrainTextDataConfig)
60
+ class BertPretrainTextDataLoader(data_loader.DataLoader):
61
+ """A class to load text dataset for BERT pretraining task."""
62
+
63
+ def __init__(self, params):
64
+ """Inits `BertPretrainTextDataLoader` class.
65
+
66
+ Args:
67
+ params: A `BertPretrainTextDataConfig` object.
68
+ """
69
+ if len(params.text_field_names) > 1 and params.use_next_sentence_label:
70
+ raise ValueError("Currently there is no support for more than text field "
71
+ "while generating next sentence labels.")
72
+
73
+ self._params = params
74
+ self._seq_length = params.seq_length
75
+ self._max_predictions_per_seq = params.max_predictions_per_seq
76
+ self._use_next_sentence_label = params.use_next_sentence_label
77
+ self._masking_rate = params.masking_rate
78
+ self._use_whole_word_masking = params.use_whole_word_masking
79
+
80
+ lookup_table_init = tf.lookup.TextFileInitializer(
81
+ params.vocab_file_path,
82
+ key_dtype=tf.string,
83
+ key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
84
+ value_dtype=tf.int64,
85
+ value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
86
+ self._vocab_lookup_table = tf.lookup.StaticVocabularyTable(
87
+ lookup_table_init,
88
+ num_oov_buckets=_NUM_OOV_BUCKETS,
89
+ lookup_key_dtype=tf.string)
90
+
91
+ self._cls_token = self._vocab_lookup_table.lookup(tf.constant(_CLS_TOKEN))
92
+ self._sep_token = self._vocab_lookup_table.lookup(tf.constant(_SEP_TOKEN))
93
+ self._mask_token = self._vocab_lookup_table.lookup(tf.constant(_MASK_TOKEN))
94
+
95
+ # -_NUM_OOV_BUCKETS to offset unused OOV bucket.
96
+ self._vocab_size = self._vocab_lookup_table.size() - _NUM_OOV_BUCKETS
97
+
98
+ def _decode(self, record: tf.Tensor) -> Mapping[Text, tf.Tensor]:
99
+ """Decodes a serialized tf.Example."""
100
+ name_to_features = {}
101
+ for text_field_name in self._params.text_field_names:
102
+ name_to_features[text_field_name] = tf.io.FixedLenFeature([], tf.string)
103
+ return tf.io.parse_single_example(record, name_to_features)
104
+
105
+ def _tokenize(self, segments):
106
+ """Tokenize the input segments."""
107
+ # Tokenize segments
108
+ tokenizer = tf_text.BertTokenizer(
109
+ self._vocab_lookup_table, token_out_type=tf.int64)
110
+
111
+ if self._use_whole_word_masking:
112
+ # tokenize the segments which should have the shape:
113
+ # [num_sentence, (num_words), (num_wordpieces)]
114
+ segments = [tokenizer.tokenize(s) for s in segments]
115
+ else:
116
+ # tokenize the segments and merge out the token dimension so that each
117
+ # segment has the shape: [num_sentence, (num_wordpieces)]
118
+ segments = [tokenizer.tokenize(s).merge_dims(-2, -1) for s in segments]
119
+
120
+ # Truncate inputs
121
+ trimmer = tf_text.WaterfallTrimmer(
122
+ self._seq_length - _NUM_SPECIAL_TOKENS, axis=-1)
123
+ truncated_segments = trimmer.trim(segments)
124
+
125
+ # Combine segments, get segment ids and add special tokens
126
+ return tf_text.combine_segments(
127
+ truncated_segments,
128
+ start_of_sequence_id=self._cls_token,
129
+ end_of_segment_id=self._sep_token)
130
+
131
+ def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
132
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
133
+ if self._use_next_sentence_label:
134
+ input_text = record[self._params.text_field_names[0]]
135
+ # Split sentences
136
+ sentence_breaker = tf_text.RegexSplitter()
137
+ sentences = sentence_breaker.split(input_text)
138
+
139
+ # Extract next-sentence-prediction labels and segments
140
+ next_or_random_segment, is_next = (
141
+ segment_extractor.get_next_sentence_labels(sentences))
142
+ # merge dims to change shape from [num_docs, (num_segments)] to
143
+ # [total_num_segments]
144
+ is_next = is_next.merge_dims(-2, -1)
145
+
146
+ # construct segments with shape [(num_sentence)]
147
+ segments = [
148
+ sentences.merge_dims(-2, -1),
149
+ next_or_random_segment.merge_dims(-2, -1)
150
+ ]
151
+ else:
152
+ segments = [record[name] for name in self._params.text_field_names]
153
+
154
+ segments_combined, segment_ids = self._tokenize(segments)
155
+
156
+ # Dynamic masking
157
+ item_selector = tf_text.RandomItemSelector(
158
+ self._max_predictions_per_seq,
159
+ selection_rate=self._masking_rate,
160
+ unselectable_ids=[self._cls_token, self._sep_token],
161
+ shuffle_fn=(tf.identity if self._params.deterministic else None))
162
+ values_chooser = tf_text.MaskValuesChooser(
163
+ vocab_size=self._vocab_size, mask_token=self._mask_token)
164
+ masked_input_ids, masked_lm_positions, masked_lm_ids = (
165
+ tf_text.mask_language_model(
166
+ segments_combined,
167
+ item_selector=item_selector,
168
+ mask_values_chooser=values_chooser,
169
+ ))
170
+
171
+ # Pad out to fixed shape and get input mask.
172
+ seq_lengths = {
173
+ "input_word_ids": self._seq_length,
174
+ "input_type_ids": self._seq_length,
175
+ "masked_lm_positions": self._max_predictions_per_seq,
176
+ "masked_lm_ids": self._max_predictions_per_seq,
177
+ }
178
+ model_inputs = {
179
+ "input_word_ids": masked_input_ids,
180
+ "input_type_ids": segment_ids,
181
+ "masked_lm_positions": masked_lm_positions,
182
+ "masked_lm_ids": masked_lm_ids,
183
+ }
184
+ padded_inputs_and_mask = tf.nest.map_structure(tf_text.pad_model_inputs,
185
+ model_inputs, seq_lengths)
186
+ model_inputs = {
187
+ k: padded_inputs_and_mask[k][0] for k in padded_inputs_and_mask
188
+ }
189
+ model_inputs["masked_lm_weights"] = tf.cast(
190
+ padded_inputs_and_mask["masked_lm_ids"][1], tf.float32)
191
+ model_inputs["input_mask"] = padded_inputs_and_mask["input_word_ids"][1]
192
+
193
+ if self._use_next_sentence_label:
194
+ model_inputs["next_sentence_labels"] = is_next
195
+
196
+ for name in model_inputs:
197
+ t = model_inputs[name]
198
+ if t.dtype == tf.int64:
199
+ t = tf.cast(t, tf.int32)
200
+ model_inputs[name] = t
201
+
202
+ return model_inputs
203
+
204
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
205
+ """Returns a tf.dataset.Dataset."""
206
+
207
+ def _batch_docs(dataset, input_context):
208
+ per_core_doc_batch_size = (
209
+ input_context.get_per_replica_batch_size(self._params.doc_batch_size)
210
+ if input_context else self._params.doc_batch_size)
211
+ return dataset.batch(per_core_doc_batch_size)
212
+
213
+ reader = input_reader.InputReader(
214
+ params=self._params,
215
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
216
+ decoder_fn=self._decode if self._params.input_path else None,
217
+ transform_and_batch_fn=_batch_docs
218
+ if self._use_next_sentence_label else None,
219
+ postprocess_fn=self._bert_preprocess)
220
+ transformed_inputs = reader.read(input_context)
221
+ per_core_example_batch_size = (
222
+ input_context.get_per_replica_batch_size(self._params.global_batch_size)
223
+ if input_context else self._params.global_batch_size)
224
+ batched_inputs = transformed_inputs.unbatch().batch(
225
+ per_core_example_batch_size, self._params.drop_remainder)
226
+ return batched_inputs.prefetch(tf.data.experimental.AUTOTUNE)
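Note: a minimal usage sketch of the text-based pretraining loader defined above. The raw-text TFRecord and WordPiece vocab paths are placeholders; the remaining values simply echo the dataclass defaults.

from official.nlp.data import pretrain_text_dataloader

data_config = pretrain_text_dataloader.BertPretrainTextDataConfig(
    input_path='/tmp/wiki_text.tfrecord',  # placeholder TFRecord of raw text
    vocab_file_path='/tmp/vocab.txt',      # placeholder WordPiece vocab
    seq_length=512,
    max_predictions_per_seq=76,
    use_next_sentence_label=True,
    global_batch_size=512)
dataset = pretrain_text_dataloader.BertPretrainTextDataLoader(
    data_config).load()
# Each batch contains input_word_ids, input_mask, input_type_ids,
# masked_lm_positions/ids/weights and next_sentence_labels.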
question_answering_dataloader.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Loads dataset for the question answering (e.g, SQuAD) task."""
16
+ import dataclasses
17
+ from typing import Mapping, Optional
18
+
19
+ import tensorflow as tf, tf_keras
20
+ from official.common import dataset_fn
21
+ from official.core import config_definitions as cfg
22
+ from official.core import input_reader
23
+ from official.nlp.data import data_loader
24
+ from official.nlp.data import data_loader_factory
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class QADataConfig(cfg.DataConfig):
29
+ """Data config for question answering task (tasks/question_answering)."""
30
+ # For training, `input_path` is expected to be a pre-processed TFRecord file,
31
+ # while for evaluation, it is expected to be a raw JSON file (b/173814590).
32
+ input_path: str = ''
33
+ global_batch_size: int = 48
34
+ is_training: bool = True
35
+ seq_length: int = 384
36
+ # Settings below are question answering specific.
37
+ version_2_with_negative: bool = False
38
+ # Settings below are only used for eval mode.
39
+ input_preprocessed_data_path: str = ''
40
+ doc_stride: int = 128
41
+ query_length: int = 64
42
+ # The path to the vocab file of word piece tokenizer or the
43
+ # model of the sentence piece tokenizer.
44
+ vocab_file: str = ''
45
+ tokenization: str = 'WordPiece' # WordPiece or SentencePiece
46
+ do_lower_case: bool = True
47
+ xlnet_format: bool = False
48
+ file_type: str = 'tfrecord'
49
+
50
+
51
+ @data_loader_factory.register_data_loader_cls(QADataConfig)
52
+ class QuestionAnsweringDataLoader(data_loader.DataLoader):
53
+ """A class to load dataset for sentence prediction (classification) task."""
54
+
55
+ def __init__(self, params):
56
+ self._params = params
57
+ self._seq_length = params.seq_length
58
+ self._is_training = params.is_training
59
+ self._xlnet_format = params.xlnet_format
60
+
61
+ def _decode(self, record: tf.Tensor):
62
+ """Decodes a serialized tf.Example."""
63
+ name_to_features = {
64
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
65
+ 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
66
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
67
+ }
68
+ if self._xlnet_format:
69
+ name_to_features['class_index'] = tf.io.FixedLenFeature([], tf.int64)
70
+ name_to_features['paragraph_mask'] = tf.io.FixedLenFeature(
71
+ [self._seq_length], tf.int64)
72
+ if self._is_training:
73
+ name_to_features['is_impossible'] = tf.io.FixedLenFeature([], tf.int64)
74
+
75
+ if self._is_training:
76
+ name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
77
+ name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
78
+ else:
79
+ name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
80
+ example = tf.io.parse_single_example(record, name_to_features)
81
+
82
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
83
+ # So cast all int64 to int32.
84
+ for name in example:
85
+ t = example[name]
86
+ if t.dtype == tf.int64:
87
+ t = tf.cast(t, tf.int32)
88
+ example[name] = t
89
+
90
+ return example
91
+
92
+ def _parse(self, record: Mapping[str, tf.Tensor]):
93
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
94
+ x, y = {}, {}
95
+ for name, tensor in record.items():
96
+ if name in ('start_positions', 'end_positions', 'is_impossible'):
97
+ y[name] = tensor
98
+ elif name == 'input_ids':
99
+ x['input_word_ids'] = tensor
100
+ elif name == 'segment_ids':
101
+ x['input_type_ids'] = tensor
102
+ else:
103
+ x[name] = tensor
104
+ if name == 'start_positions' and self._xlnet_format:
105
+ x[name] = tensor
106
+ return (x, y)
107
+
108
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
109
+ """Returns a tf.dataset.Dataset."""
110
+ reader = input_reader.InputReader(
111
+ params=self._params,
112
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
113
+ decoder_fn=self._decode,
114
+ parser_fn=self._parse)
115
+ return reader.read(input_context)
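Note: the test file that follows exercises exactly this loader; as a stand-alone sketch with a placeholder TFRecord path:

from official.nlp.data import question_answering_dataloader

data_config = question_answering_dataloader.QADataConfig(
    is_training=True,
    input_path='/tmp/squad_train.tf_record',  # placeholder pre-processed TFRecord
    seq_length=384,
    global_batch_size=48)
dataset = question_answering_dataloader.QuestionAnsweringDataLoader(
    data_config).load()
features, labels = next(iter(dataset))
# features: input_word_ids, input_mask, input_type_ids
# labels: start_positions, end_positions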
question_answering_dataloader_test.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.question_answering_dataloader."""
16
+ import os
17
+
18
+ import numpy as np
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.nlp.data import question_answering_dataloader
22
+
23
+
24
+ def _create_fake_dataset(output_path, seq_length):
25
+ """Creates a fake dataset."""
26
+ writer = tf.io.TFRecordWriter(output_path)
27
+
28
+ def create_int_feature(values):
29
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
30
+ return f
31
+
32
+ for _ in range(100):
33
+ features = {}
34
+ input_ids = np.random.randint(100, size=(seq_length))
35
+ features['input_ids'] = create_int_feature(input_ids)
36
+ features['input_mask'] = create_int_feature(np.ones_like(input_ids))
37
+ features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
38
+ features['start_positions'] = create_int_feature(np.array([0]))
39
+ features['end_positions'] = create_int_feature(np.array([10]))
40
+
41
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
42
+ writer.write(tf_example.SerializeToString())
43
+ writer.close()
44
+
45
+
46
+ class QuestionAnsweringDataTest(tf.test.TestCase):
47
+
48
+ def test_load_dataset(self):
49
+ seq_length = 128
50
+ batch_size = 10
51
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
52
+ _create_fake_dataset(input_path, seq_length)
53
+ data_config = question_answering_dataloader.QADataConfig(
54
+ is_training=True,
55
+ input_path=input_path,
56
+ seq_length=seq_length,
57
+ global_batch_size=batch_size)
58
+ dataset = question_answering_dataloader.QuestionAnsweringDataLoader(
59
+ data_config).load()
60
+ features, labels = next(iter(dataset))
61
+
62
+ self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
63
+ features.keys())
64
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
65
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
66
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
67
+
68
+ self.assertCountEqual(['start_positions', 'end_positions'], labels.keys())
69
+ self.assertEqual(labels['start_positions'].shape, (batch_size,))
70
+ self.assertEqual(labels['end_positions'].shape, (batch_size,))
71
+
72
+
73
+ if __name__ == '__main__':
74
+ tf.test.main()
sentence_prediction_dataloader.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Loads dataset for the sentence prediction (classification) task."""
16
+ import dataclasses
17
+ import functools
18
+ from typing import List, Mapping, Optional, Tuple
19
+
20
+ import tensorflow as tf, tf_keras
21
+ import tensorflow_hub as hub
22
+
23
+ from official.common import dataset_fn
24
+ from official.core import config_definitions as cfg
25
+ from official.core import input_reader
26
+ from official.nlp import modeling
27
+ from official.nlp.data import data_loader
28
+ from official.nlp.data import data_loader_factory
29
+
30
+ LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
31
+
32
+
33
+ @dataclasses.dataclass
34
+ class SentencePredictionDataConfig(cfg.DataConfig):
35
+ """Data config for sentence prediction task (tasks/sentence_prediction)."""
36
+ input_path: str = ''
37
+ global_batch_size: int = 32
38
+ is_training: bool = True
39
+ seq_length: int = 128
40
+ label_type: str = 'int'
41
+ # Whether to include the example id number.
42
+ include_example_id: bool = False
43
+ label_field: str = 'label_ids'
44
+ # Maps the key in TfExample to feature name.
45
+ # E.g 'label_ids' to 'next_sentence_labels'
46
+ label_name: Optional[Tuple[str, str]] = None
47
+ # Either tfrecord, sstable, or recordio.
48
+ file_type: str = 'tfrecord'
49
+
50
+
51
+ @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
52
+ class SentencePredictionDataLoader(data_loader.DataLoader):
53
+ """A class to load dataset for sentence prediction (classification) task."""
54
+
55
+ def __init__(self, params):
56
+ self._params = params
57
+ self._seq_length = params.seq_length
58
+ self._include_example_id = params.include_example_id
59
+ self._label_field = params.label_field
60
+ if params.label_name:
61
+ self._label_name_mapping = dict([params.label_name])
62
+ else:
63
+ self._label_name_mapping = dict()
64
+
65
+ def name_to_features_spec(self):
66
+ """Defines features to decode. Subclass may override to append features."""
67
+ label_type = LABEL_TYPES_MAP[self._params.label_type]
68
+ name_to_features = {
69
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
70
+ 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
71
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
72
+ self._label_field: tf.io.FixedLenFeature([], label_type),
73
+ }
74
+ if self._include_example_id:
75
+ name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
76
+
77
+ return name_to_features
78
+
79
+ def _decode(self, record: tf.Tensor):
80
+ """Decodes a serialized tf.Example."""
81
+ example = tf.io.parse_single_example(record, self.name_to_features_spec())
82
+
83
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
84
+ # So cast all int64 to int32.
85
+ for name in example:
86
+ t = example[name]
87
+ if t.dtype == tf.int64:
88
+ t = tf.cast(t, tf.int32)
89
+ example[name] = t
90
+
91
+ return example
92
+
93
+ def _parse(self, record: Mapping[str, tf.Tensor]):
94
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
95
+ key_mapping = {
96
+ 'input_ids': 'input_word_ids',
97
+ 'input_mask': 'input_mask',
98
+ 'segment_ids': 'input_type_ids'
99
+ }
100
+ ret = {}
101
+ for record_key in record:
102
+ if record_key in key_mapping:
103
+ ret[key_mapping[record_key]] = record[record_key]
104
+ else:
105
+ ret[record_key] = record[record_key]
106
+
107
+ if self._label_field in self._label_name_mapping:
108
+ ret[self._label_name_mapping[self._label_field]] = record[
109
+ self._label_field]
110
+
111
+ return ret
112
+
113
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
114
+ """Returns a tf.dataset.Dataset."""
115
+ reader = input_reader.InputReader(
116
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
117
+ params=self._params,
118
+ decoder_fn=self._decode,
119
+ parser_fn=self._parse)
120
+ return reader.read(input_context)
121
+
122
+
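Note: a minimal usage sketch of the pre-tokenized loader above, with a placeholder TFRecord path (the raw-text variant is covered further down).

from official.nlp.data import sentence_prediction_dataloader as loader

data_config = loader.SentencePredictionDataConfig(
    input_path='/tmp/glue_train.tf_record',  # placeholder pre-tokenized TFRecord
    seq_length=128,
    global_batch_size=32,
    label_type='int')
dataset = loader.SentencePredictionDataLoader(data_config).load()
features = next(iter(dataset))
# input_word_ids, input_mask, input_type_ids and label_ids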
123
+ @dataclasses.dataclass
124
+ class SentencePredictionTextDataConfig(cfg.DataConfig):
125
+ """Data config for sentence prediction task with raw text."""
126
+ # Either set `input_path`...
127
+ input_path: str = ''
128
+ # Either `int` or `float`.
129
+ label_type: str = 'int'
130
+ # ...or `tfds_name` and `tfds_split` to specify input.
131
+ tfds_name: str = ''
132
+ tfds_split: str = ''
133
+ # The name of the text feature fields. The text features will be
134
+ # concatenated in order.
135
+ text_fields: Optional[List[str]] = None
136
+ label_field: str = 'label'
137
+ global_batch_size: int = 32
138
+ seq_length: int = 128
139
+ is_training: bool = True
140
+ # Either build preprocessing with Python code by specifying these values
141
+ # for modeling.layers.BertTokenizer()/SentencepieceTokenizer()....
142
+ tokenization: str = 'WordPiece' # WordPiece or SentencePiece
143
+ # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
144
+ # file if tokenization is SentencePiece.
145
+ vocab_file: str = ''
146
+ lower_case: bool = True
147
+ # ...or load preprocessing from a SavedModel at this location.
148
+ preprocessing_hub_module_url: str = ''
149
+ # Either tfrecord or sstsable or recordio.
150
+ file_type: str = 'tfrecord'
151
+ include_example_id: bool = False
152
+
153
+
154
+ class TextProcessor(tf.Module):
155
+ """Text features processing for sentence prediction task."""
156
+
157
+ def __init__(self,
158
+ seq_length: int,
159
+ vocab_file: Optional[str] = None,
160
+ tokenization: Optional[str] = None,
161
+ lower_case: Optional[bool] = True,
162
+ preprocessing_hub_module_url: Optional[str] = None):
163
+ if preprocessing_hub_module_url:
164
+ self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
165
+ self._tokenizer = self._preprocessing_hub_module.tokenize
166
+ self._pack_inputs = functools.partial(
167
+ self._preprocessing_hub_module.bert_pack_inputs,
168
+ seq_length=seq_length)
169
+ return
170
+
171
+ if tokenization == 'WordPiece':
172
+ self._tokenizer = modeling.layers.BertTokenizer(
173
+ vocab_file=vocab_file, lower_case=lower_case)
174
+ elif tokenization == 'SentencePiece':
175
+ self._tokenizer = modeling.layers.SentencepieceTokenizer(
176
+ model_file_path=vocab_file,
177
+ lower_case=lower_case,
178
+ strip_diacritics=True) # Strip diacritics to follow ALBERT model
179
+ else:
180
+ raise ValueError('Unsupported tokenization: %s' % tokenization)
181
+
182
+ self._pack_inputs = modeling.layers.BertPackInputs(
183
+ seq_length=seq_length,
184
+ special_tokens_dict=self._tokenizer.get_special_tokens_dict())
185
+
186
+ def __call__(self, segments):
187
+ segments = [self._tokenizer(s) for s in segments]
188
+ # BertTokenizer returns a RaggedTensor with shape [batch, word, subword],
189
+ # and SentencepieceTokenizer returns a RaggedTensor with shape
190
+ # [batch, sentencepiece].
191
+ segments = [
192
+ tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
193
+ for x in segments
194
+ ]
195
+ return self._pack_inputs(segments)
196
+
197
+
198
+ @data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
199
+ class SentencePredictionTextDataLoader(data_loader.DataLoader):
200
+ """Loads dataset with raw text for sentence prediction task."""
201
+
202
+ def __init__(self, params):
203
+ if bool(params.tfds_name) != bool(params.tfds_split):
204
+ raise ValueError('`tfds_name` and `tfds_split` should be specified or '
205
+ 'unspecified at the same time.')
206
+ if bool(params.tfds_name) == bool(params.input_path):
207
+ raise ValueError('Must specify either `tfds_name` and `tfds_split` '
208
+ 'or `input_path`.')
209
+ if not params.text_fields:
210
+ raise ValueError('Unexpected empty text fields.')
211
+ if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
212
+ raise ValueError('Must specify exactly one of vocab_file (with matching '
213
+ 'lower_case flag) or preprocessing_hub_module_url.')
214
+
215
+ self._params = params
216
+ self._text_fields = params.text_fields
217
+ self._label_field = params.label_field
218
+ self._label_type = params.label_type
219
+ self._include_example_id = params.include_example_id
220
+ self._text_processor = TextProcessor(
221
+ seq_length=params.seq_length,
222
+ vocab_file=params.vocab_file,
223
+ tokenization=params.tokenization,
224
+ lower_case=params.lower_case,
225
+ preprocessing_hub_module_url=params.preprocessing_hub_module_url)
226
+
227
+ def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
228
+ """Berts preprocess."""
229
+ segments = [record[x] for x in self._text_fields]
230
+ model_inputs = self._text_processor(segments)
231
+ for key in record:
232
+ if key not in self._text_fields:
233
+ model_inputs[key] = record[key]
234
+ return model_inputs
235
+
236
+ def name_to_features_spec(self):
237
+ name_to_features = {}
238
+ for text_field in self._text_fields:
239
+ name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
240
+
241
+ label_type = LABEL_TYPES_MAP[self._label_type]
242
+ name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
243
+ if self._include_example_id:
244
+ name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
245
+ return name_to_features
246
+
247
+ def _decode(self, record: tf.Tensor):
248
+ """Decodes a serialized tf.Example."""
249
+ example = tf.io.parse_single_example(record, self.name_to_features_spec())
250
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
251
+ # So cast all int64 to int32.
252
+ for name in example:
253
+ t = example[name]
254
+ if t.dtype == tf.int64:
255
+ t = tf.cast(t, tf.int32)
256
+ example[name] = t
257
+
258
+ return example
259
+
260
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
261
+ """Returns a tf.dataset.Dataset."""
262
+ reader = input_reader.InputReader(
263
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
264
+ decoder_fn=self._decode if self._params.input_path else None,
265
+ params=self._params,
266
+ postprocess_fn=self._bert_preprocess)
267
+ return reader.read(input_context)
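Note: a minimal usage sketch of the raw-text loader above, using Python WordPiece preprocessing; the TFRecord and vocab paths are placeholders.

from official.nlp.data import sentence_prediction_dataloader as loader

data_config = loader.SentencePredictionTextDataConfig(
    input_path='/tmp/raw_pairs.tfrecord',  # placeholder TFRecord with raw text
    text_fields=['sentence1', 'sentence2'],
    vocab_file='/tmp/vocab.txt',           # placeholder WordPiece vocab
    tokenization='WordPiece',
    lower_case=True,
    seq_length=128,
    global_batch_size=32)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset))
# input_word_ids, input_mask, input_type_ids and the 'label' field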
sentence_prediction_dataloader_test.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.sentence_prediction_dataloader."""
16
+ import os
17
+
18
+ from absl.testing import parameterized
19
+ import numpy as np
20
+ import tensorflow as tf, tf_keras
21
+
22
+ from sentencepiece import SentencePieceTrainer
23
+ from official.nlp.data import sentence_prediction_dataloader as loader
24
+
25
+
26
+ def _create_fake_preprocessed_dataset(output_path, seq_length, label_type):
27
+ """Creates a fake dataset."""
28
+ writer = tf.io.TFRecordWriter(output_path)
29
+
30
+ def create_int_feature(values):
31
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
32
+ return f
33
+
34
+ def create_float_feature(values):
35
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
36
+ return f
37
+
38
+ for _ in range(100):
39
+ features = {}
40
+ input_ids = np.random.randint(100, size=(seq_length))
41
+ features['input_ids'] = create_int_feature(input_ids)
42
+ features['input_mask'] = create_int_feature(np.ones_like(input_ids))
43
+ features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
44
+
45
+ if label_type == 'int':
46
+ features['label_ids'] = create_int_feature([1])
47
+ elif label_type == 'float':
48
+ features['label_ids'] = create_float_feature([0.5])
49
+ else:
50
+ raise ValueError('Unsupported label_type: %s' % label_type)
51
+
52
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
53
+ writer.write(tf_example.SerializeToString())
54
+ writer.close()
55
+
56
+
57
+ def _create_fake_raw_dataset(output_path, text_fields, label_type):
58
+ """Creates a fake tf record file."""
59
+ writer = tf.io.TFRecordWriter(output_path)
60
+
61
+ def create_str_feature(value):
62
+ f = tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
63
+ return f
64
+
65
+ def create_int_feature(values):
66
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
67
+ return f
68
+
69
+ def create_float_feature(values):
70
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
71
+ return f
72
+
73
+ for _ in range(100):
74
+ features = {}
75
+ for text_field in text_fields:
76
+ features[text_field] = create_str_feature([b'hello world'])
77
+
78
+ if label_type == 'int':
79
+ features['label'] = create_int_feature([0])
80
+ elif label_type == 'float':
81
+ features['label'] = create_float_feature([0.5])
82
+ else:
83
+ raise ValueError('Unexpected label_type: %s' % label_type)
84
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
85
+ writer.write(tf_example.SerializeToString())
86
+ writer.close()
87
+
88
+
89
+ def _create_fake_sentencepiece_model(output_dir):
90
+ vocab = ['a', 'b', 'c', 'd', 'e', 'abc', 'def', 'ABC', 'DEF']
91
+ model_prefix = os.path.join(output_dir, 'spm_model')
92
+ input_text_file_path = os.path.join(output_dir, 'train_input.txt')
93
+ with tf.io.gfile.GFile(input_text_file_path, 'w') as f:
94
+ f.write(' '.join(vocab + ['\n']))
95
+ # Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
96
+ full_vocab_size = len(vocab) + 7
97
+ flags = dict(
98
+ model_prefix=model_prefix,
99
+ model_type='word',
100
+ input=input_text_file_path,
101
+ pad_id=0,
102
+ unk_id=1,
103
+ control_symbols='[CLS],[SEP],[MASK]',
104
+ vocab_size=full_vocab_size,
105
+ bos_id=full_vocab_size - 2,
106
+ eos_id=full_vocab_size - 1)
107
+ SentencePieceTrainer.Train(' '.join(
108
+ ['--{}={}'.format(k, v) for k, v in flags.items()]))
109
+ return model_prefix + '.model'
110
+
111
+
112
+ def _create_fake_vocab_file(vocab_file_path):
113
+ tokens = ['[PAD]']
114
+ for i in range(1, 100):
115
+ tokens.append('[unused%d]' % i)
116
+ tokens.extend(['[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world'])
117
+ with tf.io.gfile.GFile(vocab_file_path, 'w') as outfile:
118
+ outfile.write('\n'.join(tokens))
119
+
120
+
121
+ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
122
+
123
+ @parameterized.parameters(('int', tf.int32), ('float', tf.float32))
124
+ def test_load_dataset(self, label_type, expected_label_type):
125
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
126
+ batch_size = 10
127
+ seq_length = 128
128
+ _create_fake_preprocessed_dataset(input_path, seq_length, label_type)
129
+ data_config = loader.SentencePredictionDataConfig(
130
+ input_path=input_path,
131
+ seq_length=seq_length,
132
+ global_batch_size=batch_size,
133
+ label_type=label_type)
134
+ dataset = loader.SentencePredictionDataLoader(data_config).load()
135
+ features = next(iter(dataset))
136
+ self.assertCountEqual(
137
+ ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
138
+ features.keys())
139
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
140
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
141
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
142
+ self.assertEqual(features['label_ids'].shape, (batch_size,))
143
+ self.assertEqual(features['label_ids'].dtype, expected_label_type)
144
+
145
+ def test_load_dataset_with_label_mapping(self):
146
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
147
+ batch_size = 10
148
+ seq_length = 128
149
+ _create_fake_preprocessed_dataset(input_path, seq_length, 'int')
150
+ data_config = loader.SentencePredictionDataConfig(
151
+ input_path=input_path,
152
+ seq_length=seq_length,
153
+ global_batch_size=batch_size,
154
+ label_type='int',
155
+ label_name=('label_ids', 'next_sentence_labels'))
156
+ dataset = loader.SentencePredictionDataLoader(data_config).load()
157
+ features = next(iter(dataset))
158
+ self.assertCountEqual([
159
+ 'input_word_ids', 'input_mask', 'input_type_ids',
160
+ 'next_sentence_labels', 'label_ids'
161
+ ], features.keys())
162
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
163
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
164
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
165
+ self.assertEqual(features['label_ids'].shape, (batch_size,))
166
+ self.assertEqual(features['label_ids'].dtype, tf.int32)
167
+ self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
168
+ self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
169
+
170
+
171
+ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
172
+ parameterized.TestCase):
173
+
174
+ @parameterized.parameters(True, False)
175
+ def test_python_wordpiece_preprocessing(self, use_tfds):
176
+ batch_size = 10
177
+ seq_length = 256 # Non-default value.
178
+ lower_case = True
179
+
180
+ tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
181
+ text_fields = ['sentence1', 'sentence2']
182
+ if not use_tfds:
183
+ _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
184
+
185
+ vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
186
+ _create_fake_vocab_file(vocab_file_path)
187
+
188
+ data_config = loader.SentencePredictionTextDataConfig(
189
+ input_path='' if use_tfds else tf_record_path,
190
+ tfds_name='glue/mrpc' if use_tfds else '',
191
+ tfds_split='train' if use_tfds else '',
192
+ text_fields=text_fields,
193
+ global_batch_size=batch_size,
194
+ seq_length=seq_length,
195
+ is_training=True,
196
+ lower_case=lower_case,
197
+ vocab_file=vocab_file_path)
198
+ dataset = loader.SentencePredictionTextDataLoader(data_config).load()
199
+ features = next(iter(dataset))
200
+ label_field = data_config.label_field
201
+ expected_keys = [
202
+ 'input_word_ids', 'input_type_ids', 'input_mask', label_field
203
+ ]
204
+ if use_tfds:
205
+ expected_keys += ['idx']
206
+ self.assertCountEqual(expected_keys, features.keys())
207
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
208
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
209
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
210
+ self.assertEqual(features[label_field].shape, (batch_size,))
211
+
212
+ @parameterized.parameters(True, False)
213
+ def test_python_sentencepiece_preprocessing(self, use_tfds):
214
+ batch_size = 10
215
+ seq_length = 256 # Non-default value.
216
+ lower_case = True
217
+
218
+ tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
219
+ text_fields = ['sentence1', 'sentence2']
220
+ if not use_tfds:
221
+ _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
222
+
223
+ sp_model_file_path = _create_fake_sentencepiece_model(self.get_temp_dir())
224
+ data_config = loader.SentencePredictionTextDataConfig(
225
+ input_path='' if use_tfds else tf_record_path,
226
+ tfds_name='glue/mrpc' if use_tfds else '',
227
+ tfds_split='train' if use_tfds else '',
228
+ text_fields=text_fields,
229
+ global_batch_size=batch_size,
230
+ seq_length=seq_length,
231
+ is_training=True,
232
+ lower_case=lower_case,
233
+ tokenization='SentencePiece',
234
+ vocab_file=sp_model_file_path,
235
+ )
236
+ dataset = loader.SentencePredictionTextDataLoader(data_config).load()
237
+ features = next(iter(dataset))
238
+ label_field = data_config.label_field
239
+ expected_keys = [
240
+ 'input_word_ids', 'input_type_ids', 'input_mask', label_field
241
+ ]
242
+ if use_tfds:
243
+ expected_keys += ['idx']
244
+ self.assertCountEqual(expected_keys, features.keys())
245
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
246
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
247
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
248
+ self.assertEqual(features[label_field].shape, (batch_size,))
249
+
250
+ @parameterized.parameters(True, False)
251
+ def test_saved_model_preprocessing(self, use_tfds):
252
+ batch_size = 10
253
+ seq_length = 256 # Non-default value.
254
+
255
+ tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
256
+ text_fields = ['sentence1', 'sentence2']
257
+ if not use_tfds:
258
+ _create_fake_raw_dataset(tf_record_path, text_fields, label_type='float')
259
+
260
+ vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
261
+ _create_fake_vocab_file(vocab_file_path)
262
+ data_config = loader.SentencePredictionTextDataConfig(
263
+ input_path='' if use_tfds else tf_record_path,
264
+ tfds_name='glue/mrpc' if use_tfds else '',
265
+ tfds_split='train' if use_tfds else '',
266
+ text_fields=text_fields,
267
+ global_batch_size=batch_size,
268
+ seq_length=seq_length,
269
+ is_training=True,
270
+ preprocessing_hub_module_url=(
271
+ 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
272
+ label_type='int' if use_tfds else 'float',
273
+ )
274
+ dataset = loader.SentencePredictionTextDataLoader(data_config).load()
275
+ features = next(iter(dataset))
276
+ label_field = data_config.label_field
277
+ expected_keys = [
278
+ 'input_word_ids', 'input_type_ids', 'input_mask', label_field
279
+ ]
280
+ if use_tfds:
281
+ expected_keys += ['idx']
282
+ self.assertCountEqual(expected_keys, features.keys())
283
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
284
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
285
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
286
+ self.assertEqual(features[label_field].shape, (batch_size,))
287
+
288
+
289
+ if __name__ == '__main__':
290
+ tf.test.main()
sentence_retrieval_lib.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """BERT library to process data for cross lingual sentence retrieval task."""
16
+
17
+ import os
18
+
19
+ from absl import logging
20
+ from official.nlp.data import classifier_data_lib
21
+ from official.nlp.tools import tokenization
22
+
23
+
24
+ class BuccProcessor(classifier_data_lib.DataProcessor):
25
+ """Procssor for Xtreme BUCC data set."""
26
+ supported_languages = ["de", "fr", "ru", "zh"]
27
+
28
+ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
29
+ super(BuccProcessor, self).__init__(process_text_fn)
30
+ self.languages = BuccProcessor.supported_languages
31
+
32
+ def get_dev_examples(self, data_dir, file_pattern):
33
+ return self._create_examples(
34
+ self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
35
+ "sample")
36
+
37
+ def get_test_examples(self, data_dir, file_pattern):
38
+ return self._create_examples(
39
+ self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
40
+ "test")
41
+
42
+ @staticmethod
43
+ def get_processor_name():
44
+ """See base class."""
45
+ return "BUCC"
46
+
47
+ def _create_examples(self, lines, set_type):
48
+ """Creates examples for the training and dev sets."""
49
+ examples = []
50
+ for (i, line) in enumerate(lines):
51
+ guid = "%s-%s" % (set_type, i)
52
+ example_id = int(line[0].split("-")[1])
53
+ text_a = self.process_text_fn(line[1])
54
+ examples.append(
55
+ classifier_data_lib.InputExample(
56
+ guid=guid, text_a=text_a, example_id=example_id))
57
+ return examples
58
+
59
+
60
+ class TatoebaProcessor(classifier_data_lib.DataProcessor):
61
+ """Procssor for Xtreme Tatoeba data set."""
62
+ supported_languages = [
63
+ "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
64
+ "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
65
+ "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
66
+ ]
67
+
68
+ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
69
+ super(TatoebaProcessor, self).__init__(process_text_fn)
70
+ self.languages = TatoebaProcessor.supported_languages
71
+
72
+ def get_test_examples(self, data_dir, file_path):
73
+ return self._create_examples(
74
+ self._read_tsv(os.path.join(data_dir, file_path)), "test")
75
+
76
+ @staticmethod
77
+ def get_processor_name():
78
+ """See base class."""
79
+ return "TATOEBA"
80
+
81
+ def _create_examples(self, lines, set_type):
82
+ """Creates examples for the training and dev sets."""
83
+ examples = []
84
+ for (i, line) in enumerate(lines):
85
+ guid = "%s-%s" % (set_type, i)
86
+ text_a = self.process_text_fn(line[0])
87
+ examples.append(
88
+ classifier_data_lib.InputExample(
89
+ guid=guid, text_a=text_a, example_id=i))
90
+ return examples
91
+
92
+
93
+ def generate_sentence_retrevial_tf_record(processor,
94
+ data_dir,
95
+ tokenizer,
96
+ eval_data_output_path=None,
97
+ test_data_output_path=None,
98
+ max_seq_length=128):
99
+ """Generates the tf records for retrieval tasks.
100
+
101
+ Args:
102
+ processor: Input processor object to be used for generating data. Subclass
103
+ of `DataProcessor`.
104
+ data_dir: Directory that contains train/eval data to process. Data files
105
+ should be in from.
106
+ tokenizer: The tokenizer to be applied on the data.
107
+ eval_data_output_path: Output to which processed tf record for evaluation
108
+ will be saved.
109
+ test_data_output_path: Output to which processed tf record for testing
110
+ will be saved. Must be a pattern template with {} if processor has
111
+ language specific test data.
112
+ max_seq_length: Maximum sequence length of the to be generated
113
+ training/eval data.
114
+
115
+ Returns:
116
+ A dictionary containing input meta data.
117
+ """
118
+ assert eval_data_output_path or test_data_output_path
119
+
120
+ if processor.get_processor_name() == "BUCC":
121
+ path_pattern = "{}-en.{{}}.{}"
122
+
123
+ if processor.get_processor_name() == "TATOEBA":
124
+ path_pattern = "{}-en.{}"
125
+
126
+ meta_data = {
127
+ "processor_type": processor.get_processor_name(),
128
+ "max_seq_length": max_seq_length,
129
+ "number_eval_data": {},
130
+ "number_test_data": {},
131
+ }
132
+ logging.info("Start to process %s task data", processor.get_processor_name())
133
+
134
+ for lang_a in processor.languages:
135
+ for lang_b in [lang_a, "en"]:
136
+ if eval_data_output_path:
137
+ eval_input_data_examples = processor.get_dev_examples(
138
+ data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
139
+
140
+ num_eval_data = len(eval_input_data_examples)
141
+ logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
142
+ lang_a, lang_b)
143
+ output_file = os.path.join(
144
+ eval_data_output_path,
145
+ "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
146
+ classifier_data_lib.file_based_convert_examples_to_features(
147
+ eval_input_data_examples, None, max_seq_length, tokenizer,
148
+ output_file, None)
149
+ meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
150
+
151
+ if test_data_output_path:
152
+ test_input_data_examples = processor.get_test_examples(
153
+ data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
154
+
155
+ num_test_data = len(test_input_data_examples)
156
+ logging.info("Processing %d test examples of %s-en.%s", num_test_data,
157
+ lang_a, lang_b)
158
+ output_file = os.path.join(
159
+ test_data_output_path,
160
+ "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
161
+ classifier_data_lib.file_based_convert_examples_to_features(
162
+ test_input_data_examples, None, max_seq_length, tokenizer,
163
+ output_file, None)
164
+ meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
165
+
166
+ return meta_data
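Note: a minimal usage sketch of the retrieval-data helper above for the Tatoeba task. The data directory, output path and vocab file are placeholders, and the FullTokenizer construction is an assumption about the tokenization module rather than something shown in this commit.

from official.nlp.data import sentence_retrieval_lib
from official.nlp.tools import tokenization

processor = sentence_retrieval_lib.TatoebaProcessor()
tokenizer = tokenization.FullTokenizer(
    vocab_file='/tmp/vocab.txt', do_lower_case=True)  # placeholder vocab
meta_data = sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
    processor,
    data_dir='/tmp/tatoeba',                    # placeholder XTREME data dir
    tokenizer=tokenizer,
    test_data_output_path='/tmp/tatoeba_tfrecords',
    max_seq_length=128)
# meta_data records the number of test examples written per language pair.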
squad_lib.py ADDED
@@ -0,0 +1,975 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Library to process data for SQuAD 1.1 and SQuAD 2.0."""
16
+ # pylint: disable=g-bad-import-order
17
+ import collections
18
+ import copy
19
+ import json
20
+ import math
21
+ import os
22
+
23
+ import six
24
+
25
+ from absl import logging
26
+ import tensorflow as tf, tf_keras
27
+
28
+ from official.nlp.tools import tokenization
29
+
30
+
31
+ class SquadExample(object):
32
+ """A single training/test example for simple sequence classification.
33
+
34
+ For examples without an answer, the start and end position are -1.
35
+
36
+ Attributes:
37
+ qas_id: ID of the question-answer pair.
38
+ question_text: Original text for the question.
39
+ doc_tokens: The list of tokens in the context obtained by splitting on
40
+ whitespace only.
41
+ orig_answer_text: Original text for the answer.
42
+ start_position: Starting index of the answer in `doc_tokens`.
43
+ end_position: Ending index of the answer in `doc_tokens`.
44
+ is_impossible: Whether the question is impossible to answer given the
45
+ context. Only used in SQuAD 2.0.
46
+ """
47
+
48
+ def __init__(self,
49
+ qas_id,
50
+ question_text,
51
+ doc_tokens,
52
+ orig_answer_text=None,
53
+ start_position=None,
54
+ end_position=None,
55
+ is_impossible=False):
56
+ self.qas_id = qas_id
57
+ self.question_text = question_text
58
+ self.doc_tokens = doc_tokens
59
+ self.orig_answer_text = orig_answer_text
60
+ self.start_position = start_position
61
+ self.end_position = end_position
62
+ self.is_impossible = is_impossible
63
+
64
+ def __str__(self):
65
+ return self.__repr__()
66
+
67
+ def __repr__(self):
68
+ s = ""
69
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
70
+ s += ", question_text: %s" % (
71
+ tokenization.printable_text(self.question_text))
72
+ s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
73
+ if self.start_position:
74
+ s += ", start_position: %d" % (self.start_position)
75
+ if self.start_position:
76
+ s += ", end_position: %d" % (self.end_position)
77
+ if self.start_position:
78
+ s += ", is_impossible: %r" % (self.is_impossible)
79
+ return s
80
+
81
+
82
+ class InputFeatures(object):
83
+ """A single set of features of data."""
84
+
85
+ def __init__(self,
86
+ unique_id,
87
+ example_index,
88
+ doc_span_index,
89
+ tokens,
90
+ token_to_orig_map,
91
+ token_is_max_context,
92
+ input_ids,
93
+ input_mask,
94
+ segment_ids,
95
+ paragraph_mask=None,
96
+ class_index=None,
97
+ start_position=None,
98
+ end_position=None,
99
+ is_impossible=None):
100
+ self.unique_id = unique_id
101
+ self.example_index = example_index
102
+ self.doc_span_index = doc_span_index
103
+ self.tokens = tokens
104
+ self.token_to_orig_map = token_to_orig_map
105
+ self.token_is_max_context = token_is_max_context
106
+ self.input_ids = input_ids
107
+ self.input_mask = input_mask
108
+ self.segment_ids = segment_ids
109
+ self.start_position = start_position
110
+ self.end_position = end_position
111
+ self.is_impossible = is_impossible
112
+ self.paragraph_mask = paragraph_mask
113
+ self.class_index = class_index
114
+
115
+
116
+ class FeatureWriter(object):
117
+ """Writes InputFeature to TF example file."""
118
+
119
+ def __init__(self, filename, is_training):
120
+ self.filename = filename
121
+ self.is_training = is_training
122
+ self.num_features = 0
123
+ tf.io.gfile.makedirs(os.path.dirname(filename))
124
+ self._writer = tf.io.TFRecordWriter(filename)
125
+
126
+ def process_feature(self, feature):
127
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
128
+ self.num_features += 1
129
+
130
+ def create_int_feature(values):
131
+ feature = tf.train.Feature(
132
+ int64_list=tf.train.Int64List(value=list(values)))
133
+ return feature
134
+
135
+ features = collections.OrderedDict()
136
+ features["unique_ids"] = create_int_feature([feature.unique_id])
137
+ features["input_ids"] = create_int_feature(feature.input_ids)
138
+ features["input_mask"] = create_int_feature(feature.input_mask)
139
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
140
+
141
+ if feature.paragraph_mask is not None:
142
+ features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
143
+ if feature.class_index is not None:
144
+ features["class_index"] = create_int_feature([feature.class_index])
145
+
146
+ if self.is_training:
147
+ features["start_positions"] = create_int_feature([feature.start_position])
148
+ features["end_positions"] = create_int_feature([feature.end_position])
149
+ impossible = 0
150
+ if feature.is_impossible:
151
+ impossible = 1
152
+ features["is_impossible"] = create_int_feature([impossible])
153
+
154
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
155
+ self._writer.write(tf_example.SerializeToString())
156
+
157
+ def close(self):
158
+ self._writer.close()
159
+
160
+
161
+ def read_squad_examples(input_file, is_training,
162
+ version_2_with_negative,
163
+ translated_input_folder=None):
164
+ """Read a SQuAD json file into a list of SquadExample."""
165
+ with tf.io.gfile.GFile(input_file, "r") as reader:
166
+ input_data = json.load(reader)["data"]
167
+
168
+ if translated_input_folder is not None:
169
+ translated_files = tf.io.gfile.glob(
170
+ os.path.join(translated_input_folder, "*.json"))
171
+ for file in translated_files:
172
+ with tf.io.gfile.GFile(file, "r") as reader:
173
+ input_data.extend(json.load(reader)["data"])
174
+
175
+ def is_whitespace(c):
176
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
177
+ return True
178
+ return False
179
+
180
+ examples = []
181
+ for entry in input_data:
182
+ for paragraph in entry["paragraphs"]:
183
+ paragraph_text = paragraph["context"]
184
+ doc_tokens = []
185
+ char_to_word_offset = []
186
+ prev_is_whitespace = True
187
+ for c in paragraph_text:
188
+ if is_whitespace(c):
189
+ prev_is_whitespace = True
190
+ else:
191
+ if prev_is_whitespace:
192
+ doc_tokens.append(c)
193
+ else:
194
+ doc_tokens[-1] += c
195
+ prev_is_whitespace = False
196
+ char_to_word_offset.append(len(doc_tokens) - 1)
197
+
198
+ for qa in paragraph["qas"]:
199
+ qas_id = qa["id"]
200
+ question_text = qa["question"]
201
+ start_position = None
202
+ end_position = None
203
+ orig_answer_text = None
204
+ is_impossible = False
205
+ if is_training:
206
+
207
+ if version_2_with_negative:
208
+ is_impossible = qa["is_impossible"]
209
+ if (len(qa["answers"]) != 1) and (not is_impossible):
210
+ raise ValueError(
211
+ "For training, each question should have exactly 1 answer.")
212
+ if not is_impossible:
213
+ answer = qa["answers"][0]
214
+ orig_answer_text = answer["text"]
215
+ answer_offset = answer["answer_start"]
216
+ answer_length = len(orig_answer_text)
217
+ start_position = char_to_word_offset[answer_offset]
218
+ end_position = char_to_word_offset[answer_offset + answer_length -
219
+ 1]
220
+ # Only add answers where the text can be exactly recovered from the
221
+ # document. If this CAN'T happen it's likely due to weird Unicode
222
+ # stuff so we will just skip the example.
223
+ #
224
+ # Note that this means for training mode, every example is NOT
225
+ # guaranteed to be preserved.
226
+ actual_text = " ".join(doc_tokens[start_position:(end_position +
227
+ 1)])
228
+ cleaned_answer_text = " ".join(
229
+ tokenization.whitespace_tokenize(orig_answer_text))
230
+ if actual_text.find(cleaned_answer_text) == -1:
231
+ logging.warning("Could not find answer: '%s' vs. '%s'",
232
+ actual_text, cleaned_answer_text)
233
+ continue
234
+ else:
235
+ start_position = -1
236
+ end_position = -1
237
+ orig_answer_text = ""
238
+
239
+ example = SquadExample(
240
+ qas_id=qas_id,
241
+ question_text=question_text,
242
+ doc_tokens=doc_tokens,
243
+ orig_answer_text=orig_answer_text,
244
+ start_position=start_position,
245
+ end_position=end_position,
246
+ is_impossible=is_impossible)
247
+ examples.append(example)
248
+
249
+ return examples
250
+
251
+
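The answer alignment in read_squad_examples hinges on char_to_word_offset: every character of the context is mapped to the index of the whitespace-delimited word it falls in, so a character-level answer_start becomes a word-level start/end. A small standalone sketch of that mapping:

def char_to_word_positions(context, answer_start, answer_text):
  # Mirrors the whitespace split and the offset table built above.
  doc_tokens, char_to_word_offset = [], []
  prev_is_whitespace = True
  for c in context:
    if c in " \t\r\n" or ord(c) == 0x202F:
      prev_is_whitespace = True
    else:
      if prev_is_whitespace:
        doc_tokens.append(c)
      else:
        doc_tokens[-1] += c
      prev_is_whitespace = False
    char_to_word_offset.append(len(doc_tokens) - 1)
  start = char_to_word_offset[answer_start]
  end = char_to_word_offset[answer_start + len(answer_text) - 1]
  return doc_tokens, start, end

tokens, s, e = char_to_word_positions("The cat sat on the mat.", 8, "sat")
print(tokens[s:e + 1])  # -> ['sat']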
252
+ def convert_examples_to_features(examples,
253
+ tokenizer,
254
+ max_seq_length,
255
+ doc_stride,
256
+ max_query_length,
257
+ is_training,
258
+ output_fn,
259
+ xlnet_format=False,
260
+ batch_size=None):
261
+ """Loads a data file into a list of `InputBatch`s."""
262
+
263
+ base_id = 1000000000
264
+ unique_id = base_id
265
+ feature = None
266
+ for (example_index, example) in enumerate(examples):
267
+ query_tokens = tokenizer.tokenize(example.question_text)
268
+
269
+ if len(query_tokens) > max_query_length:
270
+ query_tokens = query_tokens[0:max_query_length]
271
+
272
+ tok_to_orig_index = []
273
+ orig_to_tok_index = []
274
+ all_doc_tokens = []
275
+ for (i, token) in enumerate(example.doc_tokens):
276
+ orig_to_tok_index.append(len(all_doc_tokens))
277
+ sub_tokens = tokenizer.tokenize(token)
278
+ for sub_token in sub_tokens:
279
+ tok_to_orig_index.append(i)
280
+ all_doc_tokens.append(sub_token)
281
+
282
+ tok_start_position = None
283
+ tok_end_position = None
284
+ if is_training and example.is_impossible:
285
+ tok_start_position = -1
286
+ tok_end_position = -1
287
+ if is_training and not example.is_impossible:
288
+ tok_start_position = orig_to_tok_index[example.start_position]
289
+ if example.end_position < len(example.doc_tokens) - 1:
290
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
291
+ else:
292
+ tok_end_position = len(all_doc_tokens) - 1
293
+ (tok_start_position, tok_end_position) = _improve_answer_span(
294
+ all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
295
+ example.orig_answer_text)
296
+
297
+ # The -3 accounts for [CLS], [SEP] and [SEP]
298
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
299
+
300
+ # We can have documents that are longer than the maximum sequence length.
301
+ # To deal with this we do a sliding window approach, where we take chunks
302
+ # of up to our max length with a stride of `doc_stride`.
303
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
304
+ "DocSpan", ["start", "length"])
305
+ doc_spans = []
306
+ start_offset = 0
307
+ while start_offset < len(all_doc_tokens):
308
+ length = len(all_doc_tokens) - start_offset
309
+ if length > max_tokens_for_doc:
310
+ length = max_tokens_for_doc
311
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
312
+ if start_offset + length == len(all_doc_tokens):
313
+ break
314
+ start_offset += min(length, doc_stride)
315
+
316
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
317
+ tokens = []
318
+ token_to_orig_map = {}
319
+ token_is_max_context = {}
320
+ segment_ids = []
321
+
322
+ # Paragraph mask used in XLNet.
323
+ # 1 represents paragraph and class tokens.
324
+ # 0 represents query and other special tokens.
325
+ paragraph_mask = []
326
+
327
+ # pylint: disable=cell-var-from-loop
328
+ def process_query(seg_q):
329
+ for token in query_tokens:
330
+ tokens.append(token)
331
+ segment_ids.append(seg_q)
332
+ paragraph_mask.append(0)
333
+ tokens.append("[SEP]")
334
+ segment_ids.append(seg_q)
335
+ paragraph_mask.append(0)
336
+
337
+ def process_paragraph(seg_p):
338
+ for i in range(doc_span.length):
339
+ split_token_index = doc_span.start + i
340
+ token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
341
+
342
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
343
+ split_token_index)
344
+ token_is_max_context[len(tokens)] = is_max_context
345
+ tokens.append(all_doc_tokens[split_token_index])
346
+ segment_ids.append(seg_p)
347
+ paragraph_mask.append(1)
348
+ tokens.append("[SEP]")
349
+ segment_ids.append(seg_p)
350
+ paragraph_mask.append(0)
351
+
352
+ def process_class(seg_class):
353
+ class_index = len(segment_ids)
354
+ tokens.append("[CLS]")
355
+ segment_ids.append(seg_class)
356
+ paragraph_mask.append(1)
357
+ return class_index
358
+
359
+ if xlnet_format:
360
+ seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
361
+ process_paragraph(seg_p)
362
+ process_query(seg_q)
363
+ class_index = process_class(seg_class)
364
+ else:
365
+ seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
366
+ class_index = process_class(seg_class)
367
+ process_query(seg_q)
368
+ process_paragraph(seg_p)
369
+
370
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
371
+
372
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
373
+ # tokens are attended to.
374
+ input_mask = [1] * len(input_ids)
375
+
376
+ # Zero-pad up to the sequence length.
377
+ while len(input_ids) < max_seq_length:
378
+ input_ids.append(0)
379
+ input_mask.append(0)
380
+ segment_ids.append(seg_pad)
381
+ paragraph_mask.append(0)
382
+
383
+ assert len(input_ids) == max_seq_length
384
+ assert len(input_mask) == max_seq_length
385
+ assert len(segment_ids) == max_seq_length
386
+ assert len(paragraph_mask) == max_seq_length
387
+
388
+ start_position = 0
389
+ end_position = 0
390
+ span_contains_answer = False
391
+
392
+ if is_training and not example.is_impossible:
393
+ # For training, if our document chunk does not contain an annotation
394
+ # we throw it out, since there is nothing to predict.
395
+ doc_start = doc_span.start
396
+ doc_end = doc_span.start + doc_span.length - 1
397
+ span_contains_answer = (tok_start_position >= doc_start and
398
+ tok_end_position <= doc_end)
399
+ if span_contains_answer:
400
+ doc_offset = 0 if xlnet_format else len(query_tokens) + 2
401
+ start_position = tok_start_position - doc_start + doc_offset
402
+ end_position = tok_end_position - doc_start + doc_offset
403
+
404
+ if example_index < 20:
405
+ logging.info("*** Example ***")
406
+ logging.info("unique_id: %s", (unique_id))
407
+ logging.info("example_index: %s", (example_index))
408
+ logging.info("doc_span_index: %s", (doc_span_index))
409
+ logging.info("tokens: %s",
410
+ " ".join([tokenization.printable_text(x) for x in tokens]))
411
+ logging.info(
412
+ "token_to_orig_map: %s", " ".join([
413
+ "%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)
414
+ ]))
415
+ logging.info(
416
+ "token_is_max_context: %s", " ".join([
417
+ "%d:%s" % (x, y)
418
+ for (x, y) in six.iteritems(token_is_max_context)
419
+ ]))
420
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
421
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
422
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
423
+ logging.info("paragraph_mask: %s", " ".join(
424
+ [str(x) for x in paragraph_mask]))
425
+ logging.info("class_index: %d", class_index)
426
+ if is_training:
427
+ if span_contains_answer:
428
+ answer_text = " ".join(tokens[start_position:(end_position + 1)])
429
+ logging.info("start_position: %d", (start_position))
430
+ logging.info("end_position: %d", (end_position))
431
+ logging.info("answer: %s", tokenization.printable_text(answer_text))
432
+ else:
433
+ logging.info("document span doesn't contain answer")
434
+
435
+ feature = InputFeatures(
436
+ unique_id=unique_id,
437
+ example_index=example_index,
438
+ doc_span_index=doc_span_index,
439
+ tokens=tokens,
440
+ paragraph_mask=paragraph_mask,
441
+ class_index=class_index,
442
+ token_to_orig_map=token_to_orig_map,
443
+ token_is_max_context=token_is_max_context,
444
+ input_ids=input_ids,
445
+ input_mask=input_mask,
446
+ segment_ids=segment_ids,
447
+ start_position=start_position,
448
+ end_position=end_position,
449
+ is_impossible=not span_contains_answer)
450
+
451
+ # Run callback
452
+ if is_training:
453
+ output_fn(feature)
454
+ else:
455
+ output_fn(feature, is_padding=False)
456
+
457
+ unique_id += 1
458
+
459
+ if not is_training and feature:
460
+ assert batch_size
461
+ num_padding = 0
462
+ num_examples = unique_id - base_id
463
+ if unique_id % batch_size != 0:
464
+ num_padding = batch_size - (num_examples % batch_size)
465
+ logging.info("Adding padding examples to make sure no partial batch.")
466
+ logging.info("Adds %d padding examples for inference.", num_padding)
467
+ dummy_feature = copy.deepcopy(feature)
468
+ for _ in range(num_padding):
469
+ dummy_feature.unique_id = unique_id
470
+
471
+ # Run callback
472
+ output_fn(feature, is_padding=True)
473
+ unique_id += 1
474
+ return unique_id - base_id
475
+
476
+
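The sliding-window logic in convert_examples_to_features is easiest to see in isolation; a minimal sketch with illustrative numbers showing how max_tokens_for_doc and doc_stride produce overlapping spans:

import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
  # Same loop as above, with each chunk capped at max_tokens_for_doc.
  spans, start_offset = [], 0
  while start_offset < num_doc_tokens:
    length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
    spans.append(DocSpan(start=start_offset, length=length))
    if start_offset + length == num_doc_tokens:
      break
    start_offset += min(length, doc_stride)
  return spans

print(make_doc_spans(num_doc_tokens=10, max_tokens_for_doc=4, doc_stride=2))
# -> [DocSpan(start=0, length=4), DocSpan(start=2, length=4),
#     DocSpan(start=4, length=4), DocSpan(start=6, length=4)]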
477
+ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
478
+ orig_answer_text):
479
+ """Returns tokenized answer spans that better match the annotated answer."""
480
+
481
+ # The SQuAD annotations are character based. We first project them to
482
+ # whitespace-tokenized words. But then after WordPiece tokenization, we can
483
+ # often find a "better match". For example:
484
+ #
485
+ # Question: What year was John Smith born?
486
+ # Context: The leader was John Smith (1895-1943).
487
+ # Answer: 1895
488
+ #
489
+ # The original whitespace-tokenized answer will be "(1895-1943).". However
490
+ # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
491
+ # the exact answer, 1895.
492
+ #
493
+ # However, this is not always possible. Consider the following:
494
+ #
495
+ # Question: What country is the top exporter of electronics?
496
+ # Context: The Japanese electronics industry is the largest in the world.
497
+ # Answer: Japan
498
+ #
499
+ # In this case, the annotator chose "Japan" as a character sub-span of
500
+ # the word "Japanese". Since our WordPiece tokenizer does not split
501
+ # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
502
+ # in SQuAD, but does happen.
503
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
504
+
505
+ for new_start in range(input_start, input_end + 1):
506
+ for new_end in range(input_end, new_start - 1, -1):
507
+ text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
508
+ if text_span == tok_answer_text:
509
+ return (new_start, new_end)
510
+
511
+ return (input_start, input_end)
512
+
513
+
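The "(1895-1943)." case from the comment in _improve_answer_span can be reproduced with a stub standing in for the WordPiece tokenizer (the real code passes the project's FullTokenizer); the search narrows the whitespace-level span [5, 10] to the single sub-token "1895":

class StubTokenizer:
  # Hypothetical sub-word splits, just enough for this demo.
  def tokenize(self, text):
    return {"1895": ["1895"]}[text]

def improve_answer_span(doc_tokens, start, end, tokenizer, orig_answer_text):
  tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
  for new_start in range(start, end + 1):
    for new_end in range(end, new_start - 1, -1):
      if " ".join(doc_tokens[new_start:new_end + 1]) == tok_answer_text:
        return new_start, new_end
  return start, end

doc_tokens = ["The", "leader", "was", "John", "Smith",
              "(", "1895", "-", "1943", ")", "."]
print(improve_answer_span(doc_tokens, 5, 10, StubTokenizer(), "1895"))  # -> (6, 6)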
514
+ def _check_is_max_context(doc_spans, cur_span_index, position):
515
+ """Check if this is the 'max context' doc span for the token."""
516
+
517
+ # Because of the sliding window approach taken to scoring documents, a single
518
+ # token can appear in multiple documents. E.g.
519
+ # Doc: the man went to the store and bought a gallon of milk
520
+ # Span A: the man went to the
521
+ # Span B: to the store and bought
522
+ # Span C: and bought a gallon of
523
+ # ...
524
+ #
525
+ # Now the word 'bought' will have two scores from spans B and C. We only
526
+ # want to consider the score with "maximum context", which we define as
527
+ # the *minimum* of its left and right context (the *sum* of left and
528
+ # right context will always be the same, of course).
529
+ #
530
+ # In the example the maximum context for 'bought' would be span C since
531
+ # it has 1 left context and 3 right context, while span B has 4 left context
532
+ # and 0 right context.
533
+ best_score = None
534
+ best_span_index = None
535
+ for (span_index, doc_span) in enumerate(doc_spans):
536
+ end = doc_span.start + doc_span.length - 1
537
+ if position < doc_span.start:
538
+ continue
539
+ if position > end:
540
+ continue
541
+ num_left_context = position - doc_span.start
542
+ num_right_context = end - position
543
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
544
+ if best_score is None or score > best_score:
545
+ best_score = score
546
+ best_span_index = span_index
547
+
548
+ return cur_span_index == best_span_index
549
+
550
+
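Applying the min(left, right) + 0.01 * length rule to the 'bought' example in the comment shows why span C wins; a short sketch with hand-built spans:

import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def max_context_score(doc_span, position):
  # Score used by _check_is_max_context for a token at `position`.
  end = doc_span.start + doc_span.length - 1
  if not doc_span.start <= position <= end:
    return None
  left = position - doc_span.start
  right = end - position
  return min(left, right) + 0.01 * doc_span.length

# Token index 7 is 'bought' in "the man went to the store and bought a gallon of milk".
span_b = DocSpan(start=3, length=5)  # "to the store and bought"
span_c = DocSpan(start=6, length=5)  # "and bought a gallon of"
print(max_context_score(span_b, 7), max_context_score(span_c, 7))  # -> 0.05 1.05, so C wins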
551
+ def write_predictions(all_examples,
552
+ all_features,
553
+ all_results,
554
+ n_best_size,
555
+ max_answer_length,
556
+ do_lower_case,
557
+ output_prediction_file,
558
+ output_nbest_file,
559
+ output_null_log_odds_file,
560
+ version_2_with_negative=False,
561
+ null_score_diff_threshold=0.0,
562
+ verbose=False):
563
+ """Write final predictions to the json file and log-odds of null if needed."""
564
+ logging.info("Writing predictions to: %s", (output_prediction_file))
565
+ logging.info("Writing nbest to: %s", (output_nbest_file))
566
+
567
+ all_predictions, all_nbest_json, scores_diff_json = (
568
+ postprocess_output(
569
+ all_examples=all_examples,
570
+ all_features=all_features,
571
+ all_results=all_results,
572
+ n_best_size=n_best_size,
573
+ max_answer_length=max_answer_length,
574
+ do_lower_case=do_lower_case,
575
+ version_2_with_negative=version_2_with_negative,
576
+ null_score_diff_threshold=null_score_diff_threshold,
577
+ verbose=verbose))
578
+
579
+ write_to_json_files(all_predictions, output_prediction_file)
580
+ write_to_json_files(all_nbest_json, output_nbest_file)
581
+ if version_2_with_negative:
582
+ write_to_json_files(scores_diff_json, output_null_log_odds_file)
583
+
584
+
585
+ def postprocess_output(all_examples,
586
+ all_features,
587
+ all_results,
588
+ n_best_size,
589
+ max_answer_length,
590
+ do_lower_case,
591
+ version_2_with_negative=False,
592
+ null_score_diff_threshold=0.0,
593
+ xlnet_format=False,
594
+ verbose=False):
595
+ """Postprocess model output, to form predicton results."""
596
+
597
+ example_index_to_features = collections.defaultdict(list)
598
+ for feature in all_features:
599
+ example_index_to_features[feature.example_index].append(feature)
600
+ unique_id_to_result = {}
601
+ for result in all_results:
602
+ unique_id_to_result[result.unique_id] = result
603
+
604
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
605
+ "PrelimPrediction",
606
+ ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
607
+
608
+ all_predictions = collections.OrderedDict()
609
+ all_nbest_json = collections.OrderedDict()
610
+ scores_diff_json = collections.OrderedDict()
611
+
612
+ for (example_index, example) in enumerate(all_examples):
613
+ features = example_index_to_features[example_index]
614
+
615
+ prelim_predictions = []
616
+ # keep track of the minimum score of null start+end of position 0
617
+ score_null = 1000000 # large and positive
618
+ min_null_feature_index = 0 # the paragraph slice with min null score
619
+ null_start_logit = 0 # the start logit at the slice with min null score
620
+ null_end_logit = 0 # the end logit at the slice with min null score
621
+ for (feature_index, feature) in enumerate(features):
622
+ if feature.unique_id not in unique_id_to_result:
623
+ logging.info("Skip eval example %s, not in pred.", feature.unique_id)
624
+ continue
625
+ result = unique_id_to_result[feature.unique_id]
626
+
627
+ # if we could have irrelevant answers, get the min score of irrelevant
628
+ if version_2_with_negative:
629
+ if xlnet_format:
630
+ feature_null_score = result.class_logits
631
+ else:
632
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
633
+ if feature_null_score < score_null:
634
+ score_null = feature_null_score
635
+ min_null_feature_index = feature_index
636
+ null_start_logit = result.start_logits[0]
637
+ null_end_logit = result.end_logits[0]
638
+ for (start_index, start_logit,
639
+ end_index, end_logit) in _get_best_indexes_and_logits(
640
+ result=result,
641
+ n_best_size=n_best_size,
642
+ xlnet_format=xlnet_format):
643
+ # We could hypothetically create invalid predictions, e.g., predict
644
+ # that the start of the span is in the question. We throw out all
645
+ # invalid predictions.
646
+ if start_index >= len(feature.tokens):
647
+ continue
648
+ if end_index >= len(feature.tokens):
649
+ continue
650
+ if start_index not in feature.token_to_orig_map:
651
+ continue
652
+ if end_index not in feature.token_to_orig_map:
653
+ continue
654
+ if not feature.token_is_max_context.get(start_index, False):
655
+ continue
656
+ if end_index < start_index:
657
+ continue
658
+ length = end_index - start_index + 1
659
+ if length > max_answer_length:
660
+ continue
661
+ prelim_predictions.append(
662
+ _PrelimPrediction(
663
+ feature_index=feature_index,
664
+ start_index=start_index,
665
+ end_index=end_index,
666
+ start_logit=start_logit,
667
+ end_logit=end_logit))
668
+
669
+ if version_2_with_negative and not xlnet_format:
670
+ prelim_predictions.append(
671
+ _PrelimPrediction(
672
+ feature_index=min_null_feature_index,
673
+ start_index=0,
674
+ end_index=0,
675
+ start_logit=null_start_logit,
676
+ end_logit=null_end_logit))
677
+ prelim_predictions = sorted(
678
+ prelim_predictions,
679
+ key=lambda x: (x.start_logit + x.end_logit),
680
+ reverse=True)
681
+
682
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
683
+ "NbestPrediction", ["text", "start_logit", "end_logit"])
684
+
685
+ seen_predictions = {}
686
+ nbest = []
687
+ for pred in prelim_predictions:
688
+ if len(nbest) >= n_best_size:
689
+ break
690
+ feature = features[pred.feature_index]
691
+ if pred.start_index > 0 or xlnet_format: # this is a non-null prediction
692
+ tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
693
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
694
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
695
+ orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
696
+ tok_text = " ".join(tok_tokens)
697
+
698
+ # De-tokenize WordPieces that have been split off.
699
+ tok_text = tok_text.replace(" ##", "")
700
+ tok_text = tok_text.replace("##", "")
701
+
702
+ # Clean whitespace
703
+ tok_text = tok_text.strip()
704
+ tok_text = " ".join(tok_text.split())
705
+ orig_text = " ".join(orig_tokens)
706
+
707
+ final_text = get_final_text(
708
+ tok_text, orig_text, do_lower_case, verbose=verbose)
709
+ if final_text in seen_predictions:
710
+ continue
711
+
712
+ seen_predictions[final_text] = True
713
+ else:
714
+ final_text = ""
715
+ seen_predictions[final_text] = True
716
+
717
+ nbest.append(
718
+ _NbestPrediction(
719
+ text=final_text,
720
+ start_logit=pred.start_logit,
721
+ end_logit=pred.end_logit))
722
+
723
+ # if we didn't include the empty option in the n-best, include it
724
+ if version_2_with_negative and not xlnet_format:
725
+ if "" not in seen_predictions:
726
+ nbest.append(
727
+ _NbestPrediction(
728
+ text="", start_logit=null_start_logit,
729
+ end_logit=null_end_logit))
730
+ # In very rare edge cases we could have no valid predictions. So we
731
+ # just create a nonce prediction in this case to avoid failure.
732
+ if not nbest:
733
+ nbest.append(
734
+ _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
735
+
736
+ assert len(nbest) >= 1
737
+
738
+ total_scores = []
739
+ best_non_null_entry = None
740
+ for entry in nbest:
741
+ total_scores.append(entry.start_logit + entry.end_logit)
742
+ if not best_non_null_entry:
743
+ if entry.text:
744
+ best_non_null_entry = entry
745
+
746
+ probs = _compute_softmax(total_scores)
747
+
748
+ nbest_json = []
749
+ for (i, entry) in enumerate(nbest):
750
+ output = collections.OrderedDict()
751
+ output["text"] = entry.text
752
+ output["probability"] = probs[i]
753
+ output["start_logit"] = entry.start_logit
754
+ output["end_logit"] = entry.end_logit
755
+ nbest_json.append(output)
756
+
757
+ assert len(nbest_json) >= 1
758
+
759
+ if not version_2_with_negative:
760
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
761
+ else:
762
+ # pytype: disable=attribute-error
763
+ # predict "" iff the null score - the score of best non-null > threshold
764
+ if best_non_null_entry is not None:
765
+ if xlnet_format:
766
+ score_diff = score_null
767
+ scores_diff_json[example.qas_id] = score_diff
768
+ all_predictions[example.qas_id] = best_non_null_entry.text
769
+ else:
770
+ score_diff = score_null - best_non_null_entry.start_logit - (
771
+ best_non_null_entry.end_logit)
772
+ scores_diff_json[example.qas_id] = score_diff
773
+ if score_diff > null_score_diff_threshold:
774
+ all_predictions[example.qas_id] = ""
775
+ else:
776
+ all_predictions[example.qas_id] = best_non_null_entry.text
777
+ else:
778
+ logging.warning("best_non_null_entry is None")
779
+ scores_diff_json[example.qas_id] = score_null
780
+ all_predictions[example.qas_id] = ""
781
+ # pytype: enable=attribute-error
782
+
783
+ all_nbest_json[example.qas_id] = nbest_json
784
+
785
+ return all_predictions, all_nbest_json, scores_diff_json
786
+
787
+
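For SQuAD 2.0, the last step of postprocess_output reduces to comparing the null score against the best non-null span; a minimal sketch of that decision with made-up logits:

def pick_answer(score_null, best_start_logit, best_end_logit, best_text,
                null_score_diff_threshold=0.0):
  # Predict "" iff null score minus the best non-null score exceeds the threshold.
  score_diff = score_null - best_start_logit - best_end_logit
  if score_diff > null_score_diff_threshold:
    return "", score_diff
  return best_text, score_diff

print(pick_answer(1.2, 3.0, 2.5, "Steve Smith"))  # keeps "Steve Smith" (diff below threshold)
print(pick_answer(7.0, 3.0, 2.5, "Steve Smith"))  # returns "" (diff above threshold)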
788
+ def write_to_json_files(json_records, json_file):
789
+ with tf.io.gfile.GFile(json_file, "w") as writer:
790
+ writer.write(json.dumps(json_records, indent=4) + "\n")
791
+
792
+
793
+ def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
794
+ """Project the tokenized prediction back to the original text."""
795
+
796
+ # When we created the data, we kept track of the alignment between original
797
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
798
+ # now `orig_text` contains the span of our original text corresponding to the
799
+ # span that we predicted.
800
+ #
801
+ # However, `orig_text` may contain extra characters that we don't want in
802
+ # our prediction.
803
+ #
804
+ # For example, let's say:
805
+ # pred_text = steve smith
806
+ # orig_text = Steve Smith's
807
+ #
808
+ # We don't want to return `orig_text` because it contains the extra "'s".
809
+ #
810
+ # We don't want to return `pred_text` because it's already been normalized
811
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
812
+ # our tokenizer does additional normalization like stripping accent
813
+ # characters).
814
+ #
815
+ # What we really want to return is "Steve Smith".
816
+ #
817
+ # Therefore, we have to apply a semi-complicated alignment heuristic between
818
+ # `pred_text` and `orig_text` to get a character-to-character alignment. This
819
+ # can fail in certain cases in which case we just return `orig_text`.
820
+
821
+ def _strip_spaces(text):
822
+ ns_chars = []
823
+ ns_to_s_map = collections.OrderedDict()
824
+ for (i, c) in enumerate(text):
825
+ if c == " ":
826
+ continue
827
+ ns_to_s_map[len(ns_chars)] = i
828
+ ns_chars.append(c)
829
+ ns_text = "".join(ns_chars)
830
+ return (ns_text, ns_to_s_map)
831
+
832
+ # We first tokenize `orig_text`, strip whitespace from the result
833
+ # and `pred_text`, and check if they are the same length. If they are
834
+ # NOT the same length, the heuristic has failed. If they are the same
835
+ # length, we assume the characters are one-to-one aligned.
836
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
837
+
838
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
839
+
840
+ start_position = tok_text.find(pred_text)
841
+ if start_position == -1:
842
+ if verbose:
843
+ logging.info("Unable to find text: '%s' in '%s'", pred_text, orig_text)
844
+ return orig_text
845
+ end_position = start_position + len(pred_text) - 1
846
+
847
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
848
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
849
+
850
+ if len(orig_ns_text) != len(tok_ns_text):
851
+ if verbose:
852
+ logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
853
+ orig_ns_text, tok_ns_text)
854
+ return orig_text
855
+
856
+ # We then project the characters in `pred_text` back to `orig_text` using
857
+ # the character-to-character alignment.
858
+ tok_s_to_ns_map = {}
859
+ for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
860
+ tok_s_to_ns_map[tok_index] = i
861
+
862
+ orig_start_position = None
863
+ if start_position in tok_s_to_ns_map:
864
+ ns_start_position = tok_s_to_ns_map[start_position]
865
+ if ns_start_position in orig_ns_to_s_map:
866
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
867
+
868
+ if orig_start_position is None:
869
+ if verbose:
870
+ logging.info("Couldn't map start position")
871
+ return orig_text
872
+
873
+ orig_end_position = None
874
+ if end_position in tok_s_to_ns_map:
875
+ ns_end_position = tok_s_to_ns_map[end_position]
876
+ if ns_end_position in orig_ns_to_s_map:
877
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
878
+
879
+ if orig_end_position is None:
880
+ if verbose:
881
+ logging.info("Couldn't map end position")
882
+ return orig_text
883
+
884
+ output_text = orig_text[orig_start_position:(orig_end_position + 1)]
885
+ return output_text
886
+
887
+
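The heart of get_final_text is the non-space character alignment; a self-contained sketch of the "Steve Smith's" example from the comment, with the lower-cased tok_text written by hand in place of BasicTokenizer output:

import collections

def strip_spaces(text):
  # Same helper as _strip_spaces above: non-space chars plus an index map.
  ns_chars, ns_to_s_map = [], collections.OrderedDict()
  for i, c in enumerate(text):
    if c == " ":
      continue
    ns_to_s_map[len(ns_chars)] = i
    ns_chars.append(c)
  return "".join(ns_chars), ns_to_s_map

pred_text = "steve smith"
orig_text = "Steve Smith's"
tok_text = "steve smith ' s"      # hand-written stand-in for the tokenized orig_text

start = tok_text.find(pred_text)             # 0
end = start + len(pred_text) - 1             # 10
orig_ns, orig_map = strip_spaces(orig_text)  # "SteveSmith's"
tok_ns, tok_map = strip_spaces(tok_text)     # "stevesmith's"
assert len(orig_ns) == len(tok_ns)           # alignment is one-to-one

tok_s_to_ns = {s: ns for ns, s in tok_map.items()}
orig_start = orig_map[tok_s_to_ns[start]]
orig_end = orig_map[tok_s_to_ns[end]]
print(orig_text[orig_start:orig_end + 1])    # -> Steve Smith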
888
+ def _get_best_indexes_and_logits(result,
889
+ n_best_size,
890
+ xlnet_format=False):
891
+ """Generates the n-best indexes and logits from a list."""
892
+ if xlnet_format:
893
+ for i in range(n_best_size):
894
+ for j in range(n_best_size):
895
+ j_index = i * n_best_size + j
896
+ yield (result.start_indexes[i], result.start_logits[i],
897
+ result.end_indexes[j_index], result.end_logits[j_index])
898
+ else:
899
+ start_index_and_score = sorted(enumerate(result.start_logits),
900
+ key=lambda x: x[1], reverse=True)
901
+ end_index_and_score = sorted(enumerate(result.end_logits),
902
+ key=lambda x: x[1], reverse=True)
903
+ for i in range(len(start_index_and_score)):
904
+ if i >= n_best_size:
905
+ break
906
+ for j in range(len(end_index_and_score)):
907
+ if j >= n_best_size:
908
+ break
909
+ yield (start_index_and_score[i][0], start_index_and_score[i][1],
910
+ end_index_and_score[j][0], end_index_and_score[j][1])
911
+
912
+
913
+ def _compute_softmax(scores):
914
+ """Compute softmax probability over raw logits."""
915
+ if not scores:
916
+ return []
917
+
918
+ max_score = None
919
+ for score in scores:
920
+ if max_score is None or score > max_score:
921
+ max_score = score
922
+
923
+ exp_scores = []
924
+ total_sum = 0.0
925
+ for score in scores:
926
+ x = math.exp(score - max_score)
927
+ exp_scores.append(x)
928
+ total_sum += x
929
+
930
+ probs = []
931
+ for score in exp_scores:
932
+ probs.append(score / total_sum)
933
+ return probs
934
+
935
+
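_compute_softmax subtracts the maximum logit before exponentiating, which leaves the probabilities unchanged but keeps math.exp from overflowing on large logits; a quick numeric check of the same formula:

import math

def softmax(scores):
  m = max(scores)
  exps = [math.exp(s - m) for s in scores]  # shifted logits stay in a safe range
  total = sum(exps)
  return [e / total for e in exps]

print([round(p, 3) for p in softmax([1000.0, 999.0, 998.0])])
# -> [0.665, 0.245, 0.09]; math.exp(1000.0) on its own raises OverflowError.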
936
+ def generate_tf_record_from_json_file(input_file_path,
937
+ vocab_file_path,
938
+ output_path,
939
+ translated_input_folder=None,
940
+ max_seq_length=384,
941
+ do_lower_case=True,
942
+ max_query_length=64,
943
+ doc_stride=128,
944
+ version_2_with_negative=False,
945
+ xlnet_format=False):
946
+ """Generates and saves training data into a tf record file."""
947
+ train_examples = read_squad_examples(
948
+ input_file=input_file_path,
949
+ is_training=True,
950
+ version_2_with_negative=version_2_with_negative,
951
+ translated_input_folder=translated_input_folder)
952
+ tokenizer = tokenization.FullTokenizer(
953
+ vocab_file=vocab_file_path, do_lower_case=do_lower_case)
954
+ train_writer = FeatureWriter(filename=output_path, is_training=True)
955
+ number_of_examples = convert_examples_to_features(
956
+ examples=train_examples,
957
+ tokenizer=tokenizer,
958
+ max_seq_length=max_seq_length,
959
+ doc_stride=doc_stride,
960
+ max_query_length=max_query_length,
961
+ is_training=True,
962
+ output_fn=train_writer.process_feature,
963
+ xlnet_format=xlnet_format)
964
+ train_writer.close()
965
+
966
+ meta_data = {
967
+ "task_type": "bert_squad",
968
+ "train_data_size": number_of_examples,
969
+ "max_seq_length": max_seq_length,
970
+ "max_query_length": max_query_length,
971
+ "doc_stride": doc_stride,
972
+ "version_2_with_negative": version_2_with_negative,
973
+ }
974
+
975
+ return meta_data
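A typical call to generate_tf_record_from_json_file looks like the following; the paths are hypothetical and the import assumes this module is used as official.nlp.data.squad_lib:

from official.nlp.data import squad_lib

meta_data = squad_lib.generate_tf_record_from_json_file(
    input_file_path="/tmp/train-v1.1.json",  # hypothetical SQuAD json
    vocab_file_path="/tmp/vocab.txt",        # hypothetical BERT vocab
    output_path="/tmp/train.tf_record",
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    version_2_with_negative=False)
print(meta_data["train_data_size"])  # number of features written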
squad_lib_sp.py ADDED
@@ -0,0 +1,976 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.
16
+
17
+ The file is forked from:
18
+
19
+ https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
20
+ """
21
+ import collections
22
+ import copy
23
+ import json
24
+ import math
25
+ import os
26
+
27
+ from absl import logging
28
+ import numpy as np
29
+ import tensorflow as tf, tf_keras
30
+
31
+ from official.nlp.tools import tokenization
32
+
33
+
34
+ class SquadExample(object):
35
+ """A single training/test example for simple sequence classification.
36
+
37
+ For examples without an answer, the start and end position are -1.
38
+ """
39
+
40
+ def __init__(self,
41
+ qas_id,
42
+ question_text,
43
+ paragraph_text,
44
+ orig_answer_text=None,
45
+ start_position=None,
46
+ end_position=None,
47
+ is_impossible=False):
48
+ self.qas_id = qas_id
49
+ self.question_text = question_text
50
+ self.paragraph_text = paragraph_text
51
+ self.orig_answer_text = orig_answer_text
52
+ self.start_position = start_position
53
+ self.end_position = end_position
54
+ self.is_impossible = is_impossible
55
+
56
+ def __str__(self):
57
+ return self.__repr__()
58
+
59
+ def __repr__(self):
60
+ s = ""
61
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
62
+ s += ", question_text: %s" % (
63
+ tokenization.printable_text(self.question_text))
64
+ s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
65
+ if self.start_position:
66
+ s += ", start_position: %d" % (self.start_position,)
67
+ if self.start_position:
68
+ s += ", end_position: %d" % (self.end_position)
69
+ if self.start_position:
70
+ s += ", is_impossible: %r" % (self.is_impossible)
71
+ return s
72
+
73
+
74
+ class InputFeatures(object):
75
+ """A single set of features of data."""
76
+
77
+ def __init__(self,
78
+ unique_id,
79
+ example_index,
80
+ doc_span_index,
81
+ tok_start_to_orig_index,
82
+ tok_end_to_orig_index,
83
+ token_is_max_context,
84
+ tokens,
85
+ input_ids,
86
+ input_mask,
87
+ segment_ids,
88
+ paragraph_len,
89
+ class_index=None,
90
+ paragraph_mask=None,
91
+ start_position=None,
92
+ end_position=None,
93
+ is_impossible=None):
94
+ self.unique_id = unique_id
95
+ self.example_index = example_index
96
+ self.doc_span_index = doc_span_index
97
+ self.tok_start_to_orig_index = tok_start_to_orig_index
98
+ self.tok_end_to_orig_index = tok_end_to_orig_index
99
+ self.token_is_max_context = token_is_max_context
100
+ self.tokens = tokens
101
+ self.input_ids = input_ids
102
+ self.input_mask = input_mask
103
+ self.paragraph_mask = paragraph_mask
104
+ self.segment_ids = segment_ids
105
+ self.paragraph_len = paragraph_len
106
+ self.class_index = class_index
107
+ self.start_position = start_position
108
+ self.end_position = end_position
109
+ self.is_impossible = is_impossible
110
+
111
+
112
+ def read_squad_examples(input_file,
113
+ is_training,
114
+ version_2_with_negative,
115
+ translated_input_folder=None):
116
+ """Read a SQuAD json file into a list of SquadExample."""
117
+ del version_2_with_negative
118
+ with tf.io.gfile.GFile(input_file, "r") as reader:
119
+ input_data = json.load(reader)["data"]
120
+
121
+ if translated_input_folder is not None:
122
+ translated_files = tf.io.gfile.glob(
123
+ os.path.join(translated_input_folder, "*.json"))
124
+ for file in translated_files:
125
+ with tf.io.gfile.GFile(file, "r") as reader:
126
+ input_data.extend(json.load(reader)["data"])
127
+
128
+ examples = []
129
+ for entry in input_data:
130
+ for paragraph in entry["paragraphs"]:
131
+ paragraph_text = paragraph["context"]
132
+
133
+ for qa in paragraph["qas"]:
134
+ qas_id = qa["id"]
135
+ question_text = qa["question"]
136
+ start_position = None
137
+ orig_answer_text = None
138
+ is_impossible = False
139
+
140
+ if is_training:
141
+ is_impossible = qa.get("is_impossible", False)
142
+ if (len(qa["answers"]) != 1) and (not is_impossible):
143
+ raise ValueError(
144
+ "For training, each question should have exactly 1 answer.")
145
+ if not is_impossible:
146
+ answer = qa["answers"][0]
147
+ orig_answer_text = answer["text"]
148
+ start_position = answer["answer_start"]
149
+ else:
150
+ start_position = -1
151
+ orig_answer_text = ""
152
+
153
+ example = SquadExample(
154
+ qas_id=qas_id,
155
+ question_text=question_text,
156
+ paragraph_text=paragraph_text,
157
+ orig_answer_text=orig_answer_text,
158
+ start_position=start_position,
159
+ is_impossible=is_impossible)
160
+ examples.append(example)
161
+
162
+ return examples
163
+
164
+
165
+ def _convert_index(index, pos, m=None, is_start=True):
166
+ """Converts index."""
167
+ if index[pos] is not None:
168
+ return index[pos]
169
+ n = len(index)
170
+ rear = pos
171
+ while rear < n - 1 and index[rear] is None:
172
+ rear += 1
173
+ front = pos
174
+ while front > 0 and index[front] is None:
175
+ front -= 1
176
+ assert index[front] is not None or index[rear] is not None
177
+ if index[front] is None:
178
+ if index[rear] >= 1: # pytype: disable=unsupported-operands
179
+ if is_start:
180
+ return 0
181
+ else:
182
+ return index[rear] - 1
183
+ return index[rear]
184
+ if index[rear] is None:
185
+ if m is not None and index[front] < m - 1:
186
+ if is_start:
187
+ return index[front] + 1
188
+ else:
189
+ return m - 1
190
+ return index[front]
191
+ if is_start:
192
+ if index[rear] > index[front] + 1:
193
+ return index[front] + 1
194
+ else:
195
+ return index[rear]
196
+ else:
197
+ if index[rear] > index[front] + 1:
198
+ return index[rear] - 1
199
+ else:
200
+ return index[front]
201
+
202
+
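_convert_index is what keeps the character alignment usable when the LCS match further down leaves gaps (None entries): it falls back to the nearest mapped neighbour on the appropriate side. A small example, calling the function above on a hand-made alignment:

chartok_to_orig = [0, None, None, 3, 4]  # hypothetical alignment with a gap

print(_convert_index(chartok_to_orig, 0, is_start=True))   # -> 0 (already mapped)
print(_convert_index(chartok_to_orig, 1, is_start=True))   # -> 1 (front neighbour + 1)
print(_convert_index(chartok_to_orig, 2, is_start=False))  # -> 2 (rear neighbour - 1)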
203
+ def convert_examples_to_features(examples,
204
+ tokenizer,
205
+ max_seq_length,
206
+ doc_stride,
207
+ max_query_length,
208
+ is_training,
209
+ output_fn,
210
+ do_lower_case,
211
+ xlnet_format=False,
212
+ batch_size=None):
213
+ """Loads a data file into a list of `InputBatch`s."""
214
+ cnt_pos, cnt_neg = 0, 0
215
+ base_id = 1000000000
216
+ unique_id = base_id
217
+ max_n, max_m = 1024, 1024
218
+ f = np.zeros((max_n, max_m), dtype=np.float32)
219
+
220
+ for (example_index, example) in enumerate(examples):
221
+
222
+ if example_index % 100 == 0:
223
+ logging.info("Converting %d/%d pos %d neg %d", example_index,
224
+ len(examples), cnt_pos, cnt_neg)
225
+
226
+ query_tokens = tokenization.encode_ids(
227
+ tokenizer.sp_model,
228
+ tokenization.preprocess_text(
229
+ example.question_text, lower=do_lower_case))
230
+
231
+ if len(query_tokens) > max_query_length:
232
+ query_tokens = query_tokens[0:max_query_length]
233
+
234
+ paragraph_text = example.paragraph_text
235
+ para_tokens = tokenization.encode_pieces(
236
+ tokenizer.sp_model,
237
+ tokenization.preprocess_text(
238
+ example.paragraph_text, lower=do_lower_case))
239
+
240
+ chartok_to_tok_index = []
241
+ tok_start_to_chartok_index = []
242
+ tok_end_to_chartok_index = []
243
+ char_cnt = 0
244
+ for i, token in enumerate(para_tokens):
245
+ new_token = token.replace(tokenization.SPIECE_UNDERLINE, " ")
246
+ chartok_to_tok_index.extend([i] * len(new_token))
247
+ tok_start_to_chartok_index.append(char_cnt)
248
+ char_cnt += len(new_token)
249
+ tok_end_to_chartok_index.append(char_cnt - 1)
250
+
251
+ tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE,
252
+ " ")
253
+ n, m = len(paragraph_text), len(tok_cat_text)
254
+
255
+ if n > max_n or m > max_m:
256
+ max_n = max(n, max_n)
257
+ max_m = max(m, max_m)
258
+ f = np.zeros((max_n, max_m), dtype=np.float32)
259
+
260
+ g = {}
261
+
262
+ # pylint: disable=cell-var-from-loop
263
+ def _lcs_match(max_dist, n=n, m=m):
264
+ """Longest-common-substring algorithm."""
265
+ f.fill(0)
266
+ g.clear()
267
+
268
+ ### longest common sub sequence
269
+ # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
270
+ for i in range(n):
271
+
272
+ # unlike standard LCS, this is specifically optimized for the setting
273
+ # because the mismatch between sentence pieces and original text will
274
+ # be small
275
+ for j in range(i - max_dist, i + max_dist):
276
+ if j >= m or j < 0:
277
+ continue
278
+
279
+ if i > 0:
280
+ g[(i, j)] = 0
281
+ f[i, j] = f[i - 1, j]
282
+
283
+ if j > 0 and f[i, j - 1] > f[i, j]:
284
+ g[(i, j)] = 1
285
+ f[i, j] = f[i, j - 1]
286
+
287
+ f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
288
+ if (tokenization.preprocess_text(
289
+ paragraph_text[i], lower=do_lower_case,
290
+ remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
291
+ g[(i, j)] = 2
292
+ f[i, j] = f_prev + 1
293
+
294
+ # pylint: enable=cell-var-from-loop
295
+
296
+ max_dist = abs(n - m) + 5
297
+ for _ in range(2):
298
+ _lcs_match(max_dist)
299
+ if f[n - 1, m - 1] > 0.8 * n:
300
+ break
301
+ max_dist *= 2
302
+
303
+ orig_to_chartok_index = [None] * n
304
+ chartok_to_orig_index = [None] * m
305
+ i, j = n - 1, m - 1
306
+ while i >= 0 and j >= 0:
307
+ if (i, j) not in g:
308
+ break
309
+ if g[(i, j)] == 2:
310
+ orig_to_chartok_index[i] = j
311
+ chartok_to_orig_index[j] = i
312
+ i, j = i - 1, j - 1
313
+ elif g[(i, j)] == 1:
314
+ j = j - 1
315
+ else:
316
+ i = i - 1
317
+
318
+ if (all(v is None for v in orig_to_chartok_index) or
319
+ f[n - 1, m - 1] < 0.8 * n):
320
+ logging.info("MISMATCH DETECTED!")
321
+ continue
322
+
323
+ tok_start_to_orig_index = []
324
+ tok_end_to_orig_index = []
325
+ for i in range(len(para_tokens)):
326
+ start_chartok_pos = tok_start_to_chartok_index[i]
327
+ end_chartok_pos = tok_end_to_chartok_index[i]
328
+ start_orig_pos = _convert_index(
329
+ chartok_to_orig_index, start_chartok_pos, n, is_start=True)
330
+ end_orig_pos = _convert_index(
331
+ chartok_to_orig_index, end_chartok_pos, n, is_start=False)
332
+
333
+ tok_start_to_orig_index.append(start_orig_pos)
334
+ tok_end_to_orig_index.append(end_orig_pos)
335
+
336
+ if not is_training:
337
+ tok_start_position = tok_end_position = None
338
+
339
+ if is_training and example.is_impossible:
340
+ tok_start_position = 0
341
+ tok_end_position = 0
342
+
343
+ if is_training and not example.is_impossible:
344
+ start_position = example.start_position
345
+ end_position = start_position + len(example.orig_answer_text) - 1
346
+
347
+ start_chartok_pos = _convert_index(
348
+ orig_to_chartok_index, start_position, is_start=True)
349
+ tok_start_position = chartok_to_tok_index[start_chartok_pos]
350
+
351
+ end_chartok_pos = _convert_index(
352
+ orig_to_chartok_index, end_position, is_start=False)
353
+ tok_end_position = chartok_to_tok_index[end_chartok_pos]
354
+ assert tok_start_position <= tok_end_position
355
+
356
+ def _piece_to_id(x):
357
+ return tokenizer.sp_model.PieceToId(x)
358
+
359
+ all_doc_tokens = list(map(_piece_to_id, para_tokens))
360
+
361
+ # The -3 accounts for [CLS], [SEP] and [SEP]
362
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
363
+
364
+ # We can have documents that are longer than the maximum sequence length.
365
+ # To deal with this we do a sliding window approach, where we take chunks
366
+ # of up to our max length with a stride of `doc_stride`.
367
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
368
+ "DocSpan", ["start", "length"])
369
+ doc_spans = []
370
+ start_offset = 0
371
+
372
+ while start_offset < len(all_doc_tokens):
373
+ length = len(all_doc_tokens) - start_offset
374
+ if length > max_tokens_for_doc:
375
+ length = max_tokens_for_doc
376
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
377
+ if start_offset + length == len(all_doc_tokens):
378
+ break
379
+ start_offset += min(length, doc_stride)
380
+
381
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
382
+ tokens = []
383
+ token_is_max_context = {}
384
+ segment_ids = []
385
+
386
+ # Paragraph mask used in XLNet.
387
+ # 1 represents paragraph and class tokens.
388
+ # 0 represents query and other special tokens.
389
+ paragraph_mask = []
390
+
391
+ cur_tok_start_to_orig_index = []
392
+ cur_tok_end_to_orig_index = []
393
+
394
+ # pylint: disable=cell-var-from-loop
395
+ def process_query(seg_q):
396
+ for token in query_tokens:
397
+ tokens.append(token)
398
+ segment_ids.append(seg_q)
399
+ paragraph_mask.append(0)
400
+ tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
401
+ segment_ids.append(seg_q)
402
+ paragraph_mask.append(0)
403
+
404
+ def process_paragraph(seg_p):
405
+ for i in range(doc_span.length):
406
+ split_token_index = doc_span.start + i
407
+
408
+ cur_tok_start_to_orig_index.append(
409
+ tok_start_to_orig_index[split_token_index])
410
+ cur_tok_end_to_orig_index.append(
411
+ tok_end_to_orig_index[split_token_index])
412
+
413
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
414
+ split_token_index)
415
+ token_is_max_context[len(tokens)] = is_max_context
416
+ tokens.append(all_doc_tokens[split_token_index])
417
+ segment_ids.append(seg_p)
418
+ paragraph_mask.append(1)
419
+ tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
420
+ segment_ids.append(seg_p)
421
+ paragraph_mask.append(0)
422
+ return len(tokens)
423
+
424
+ def process_class(seg_class):
425
+ class_index = len(segment_ids)
426
+ tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
427
+ segment_ids.append(seg_class)
428
+ paragraph_mask.append(1)
429
+ return class_index
430
+
431
+ if xlnet_format:
432
+ seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
433
+ paragraph_len = process_paragraph(seg_p)
434
+ process_query(seg_q)
435
+ class_index = process_class(seg_class)
436
+ else:
437
+ seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
438
+ class_index = process_class(seg_class)
439
+ process_query(seg_q)
440
+ paragraph_len = process_paragraph(seg_p)
441
+
442
+ input_ids = tokens
443
+
444
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
445
+ # tokens are attended to.
446
+ input_mask = [1] * len(input_ids)
447
+
448
+ # Zero-pad up to the sequence length.
449
+ while len(input_ids) < max_seq_length:
450
+ input_ids.append(0)
451
+ input_mask.append(0)
452
+ segment_ids.append(seg_pad)
453
+ paragraph_mask.append(0)
454
+
455
+ assert len(input_ids) == max_seq_length
456
+ assert len(input_mask) == max_seq_length
457
+ assert len(segment_ids) == max_seq_length
458
+ assert len(paragraph_mask) == max_seq_length
459
+
460
+ span_is_impossible = example.is_impossible
461
+ start_position = None
462
+ end_position = None
463
+ if is_training and not span_is_impossible:
464
+ # For training, if our document chunk does not contain an annotation
465
+ # we throw it out, since there is nothing to predict.
466
+ doc_start = doc_span.start
467
+ doc_end = doc_span.start + doc_span.length - 1
468
+ out_of_span = False
469
+ if not (tok_start_position >= doc_start and
470
+ tok_end_position <= doc_end):
471
+ out_of_span = True
472
+ if out_of_span:
473
+ # continue
474
+ start_position = 0
475
+ end_position = 0
476
+ span_is_impossible = True
477
+ else:
478
+ doc_offset = 0 if xlnet_format else len(query_tokens) + 2
479
+ start_position = tok_start_position - doc_start + doc_offset
480
+ end_position = tok_end_position - doc_start + doc_offset
481
+
482
+ if is_training and span_is_impossible:
483
+ start_position = class_index
484
+ end_position = class_index
485
+
486
+ if example_index < 20:
487
+ logging.info("*** Example ***")
488
+ logging.info("unique_id: %s", (unique_id))
489
+ logging.info("example_index: %s", (example_index))
490
+ logging.info("doc_span_index: %s", (doc_span_index))
491
+ logging.info("tok_start_to_orig_index: %s",
492
+ " ".join([str(x) for x in cur_tok_start_to_orig_index]))
493
+ logging.info("tok_end_to_orig_index: %s",
494
+ " ".join([str(x) for x in cur_tok_end_to_orig_index]))
495
+ logging.info(
496
+ "token_is_max_context: %s", " ".join(
497
+ ["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
498
+ logging.info(
499
+ "input_pieces: %s",
500
+ " ".join([tokenizer.sp_model.IdToPiece(x) for x in tokens]))
501
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
502
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
503
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
504
+ logging.info("paragraph_mask: %s", " ".join(
505
+ [str(x) for x in paragraph_mask]))
506
+ logging.info("class_index: %d", class_index)
507
+
508
+ if is_training and span_is_impossible:
509
+ logging.info("impossible example span")
510
+
511
+ if is_training and not span_is_impossible:
512
+ pieces = [
513
+ tokenizer.sp_model.IdToPiece(token)
514
+ for token in tokens[start_position:(end_position + 1)]
515
+ ]
516
+ answer_text = tokenizer.sp_model.DecodePieces(pieces)
517
+ logging.info("start_position: %d", (start_position))
518
+ logging.info("end_position: %d", (end_position))
519
+ logging.info("answer: %s", (tokenization.printable_text(answer_text)))
520
+
521
+ # With multi processing, the example_index is actually the index
522
+ # within the current process therefore we use example_index=None
523
+ # to avoid being used in the future.
524
+ # The current code does not use example_index of training data.
525
+ if is_training:
526
+ feat_example_index = None
527
+ else:
528
+ feat_example_index = example_index
529
+
530
+ feature = InputFeatures(
531
+ unique_id=unique_id,
532
+ example_index=feat_example_index,
533
+ doc_span_index=doc_span_index,
534
+ tok_start_to_orig_index=cur_tok_start_to_orig_index,
535
+ tok_end_to_orig_index=cur_tok_end_to_orig_index,
536
+ token_is_max_context=token_is_max_context,
537
+ tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
538
+ input_ids=input_ids,
539
+ input_mask=input_mask,
540
+ paragraph_mask=paragraph_mask,
541
+ segment_ids=segment_ids,
542
+ paragraph_len=paragraph_len,
543
+ class_index=class_index,
544
+ start_position=start_position,
545
+ end_position=end_position,
546
+ is_impossible=span_is_impossible)
547
+
548
+ # Run callback
549
+ if is_training:
550
+ output_fn(feature)
551
+ else:
552
+ output_fn(feature, is_padding=False)
553
+
554
+ unique_id += 1
555
+ if span_is_impossible:
556
+ cnt_neg += 1
557
+ else:
558
+ cnt_pos += 1
559
+
560
+ if not is_training and feature:
561
+ assert batch_size
562
+ num_padding = 0
563
+ num_examples = unique_id - base_id
564
+ if unique_id % batch_size != 0:
565
+ num_padding = batch_size - (num_examples % batch_size)
566
+ dummy_feature = copy.deepcopy(feature)
567
+ for _ in range(num_padding):
568
+ dummy_feature.unique_id = unique_id
569
+
570
+ # Run callback
571
+ output_fn(feature, is_padding=True)
572
+ unique_id += 1
573
+
574
+ logging.info("Total number of instances: %d = pos %d neg %d",
575
+ cnt_pos + cnt_neg, cnt_pos, cnt_neg)
576
+ return unique_id - base_id
577
+
578
+
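The piece-to-character bookkeeping at the top of convert_examples_to_features is what later lets a predicted token span be projected back onto the raw paragraph text; a hand-made example with three sentencepiece pieces ("▁" marks a word boundary, matching the marker the sentencepiece model uses):

SPIECE_UNDERLINE = "▁"

para_tokens = ["▁the", "▁cat", "▁sat"]
chartok_to_tok_index = []
tok_start_to_chartok_index = []
tok_end_to_chartok_index = []
char_cnt = 0
for i, token in enumerate(para_tokens):
  new_token = token.replace(SPIECE_UNDERLINE, " ")
  chartok_to_tok_index.extend([i] * len(new_token))
  tok_start_to_chartok_index.append(char_cnt)
  char_cnt += len(new_token)
  tok_end_to_chartok_index.append(char_cnt - 1)

print("".join(para_tokens).replace(SPIECE_UNDERLINE, " "))  # -> " the cat sat"
print(tok_start_to_chartok_index)                           # -> [0, 4, 8]
print(tok_end_to_chartok_index)                             # -> [3, 7, 11]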
579
+ def _check_is_max_context(doc_spans, cur_span_index, position):
580
+ """Check if this is the 'max context' doc span for the token."""
581
+
582
+ # Because of the sliding window approach taken to scoring documents, a single
583
+ # token can appear in multiple documents. E.g.
584
+ # Doc: the man went to the store and bought a gallon of milk
585
+ # Span A: the man went to the
586
+ # Span B: to the store and bought
587
+ # Span C: and bought a gallon of
588
+ # ...
589
+ #
590
+ # Now the word 'bought' will have two scores from spans B and C. We only
591
+ # want to consider the score with "maximum context", which we define as
592
+ # the *minimum* of its left and right context (the *sum* of left and
593
+ # right context will always be the same, of course).
594
+ #
595
+ # In the example the maximum context for 'bought' would be span C since
596
+ # it has 1 left context and 3 right context, while span B has 4 left context
597
+ # and 0 right context.
598
+ best_score = None
599
+ best_span_index = None
600
+ for (span_index, doc_span) in enumerate(doc_spans):
601
+ end = doc_span.start + doc_span.length - 1
602
+ if position < doc_span.start:
603
+ continue
604
+ if position > end:
605
+ continue
606
+ num_left_context = position - doc_span.start
607
+ num_right_context = end - position
608
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
609
+ if best_score is None or score > best_score:
610
+ best_score = score
611
+ best_span_index = span_index
612
+
613
+ return cur_span_index == best_span_index
614
+
615
+
616
+ def write_predictions(all_examples,
617
+ all_features,
618
+ all_results,
619
+ n_best_size,
620
+ max_answer_length,
621
+ do_lower_case,
622
+ output_prediction_file,
623
+ output_nbest_file,
624
+ output_null_log_odds_file,
625
+ version_2_with_negative=False,
626
+ null_score_diff_threshold=0.0,
627
+ verbose=False):
628
+ """Write final predictions to the json file and log-odds of null if needed."""
629
+ logging.info("Writing predictions to: %s", (output_prediction_file))
630
+ logging.info("Writing nbest to: %s", (output_nbest_file))
631
+
632
+ all_predictions, all_nbest_json, scores_diff_json = (
633
+ postprocess_output(
634
+ all_examples=all_examples,
635
+ all_features=all_features,
636
+ all_results=all_results,
637
+ n_best_size=n_best_size,
638
+ max_answer_length=max_answer_length,
639
+ do_lower_case=do_lower_case,
640
+ version_2_with_negative=version_2_with_negative,
641
+ null_score_diff_threshold=null_score_diff_threshold,
642
+ verbose=verbose))
643
+
644
+ write_to_json_files(all_predictions, output_prediction_file)
645
+ write_to_json_files(all_nbest_json, output_nbest_file)
646
+ if version_2_with_negative:
647
+ write_to_json_files(scores_diff_json, output_null_log_odds_file)
648
+
649
+
650
+ def postprocess_output(all_examples,
651
+ all_features,
652
+ all_results,
653
+ n_best_size,
654
+ max_answer_length,
655
+ do_lower_case,
656
+ version_2_with_negative=False,
657
+ null_score_diff_threshold=0.0,
658
+ xlnet_format=False,
659
+ verbose=False):
660
+ """Postprocess model output, to form predicton results."""
661
+
662
+ del do_lower_case, verbose
663
+ example_index_to_features = collections.defaultdict(list)
664
+ for feature in all_features:
665
+ example_index_to_features[feature.example_index].append(feature)
666
+
667
+ unique_id_to_result = {}
668
+ for result in all_results:
669
+ unique_id_to_result[result.unique_id] = result
670
+
671
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
672
+ "PrelimPrediction",
673
+ ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
674
+
675
+ all_predictions = collections.OrderedDict()
676
+ all_nbest_json = collections.OrderedDict()
677
+ scores_diff_json = collections.OrderedDict()
678
+
679
+ for (example_index, example) in enumerate(all_examples):
680
+ features = example_index_to_features[example_index]
681
+
682
+ prelim_predictions = []
683
+ # keep track of the minimum score of null start+end of position 0
684
+ score_null = 1000000 # large and positive
685
+ min_null_feature_index = 0 # the paragraph slice with min null score
686
+ null_start_logit = 0 # the start logit at the slice with min null score
687
+ null_end_logit = 0 # the end logit at the slice with min null score
688
+ for (feature_index, feature) in enumerate(features):
689
+ if feature.unique_id not in unique_id_to_result:
690
+ logging.info("Skip eval example %s, not in pred.", feature.unique_id)
691
+ continue
692
+ result = unique_id_to_result[feature.unique_id]
693
+
694
+ # if we could have irrelevant answers, get the min score of irrelevant
695
+ if version_2_with_negative:
696
+ if xlnet_format:
697
+ feature_null_score = result.class_logits
698
+ else:
699
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
700
+ if feature_null_score < score_null:
701
+ score_null = feature_null_score
702
+ min_null_feature_index = feature_index
703
+ null_start_logit = result.start_logits[0]
704
+ null_end_logit = result.end_logits[0]
705
+
706
+ doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
707
+
708
+ for (start_index, start_logit,
709
+ end_index, end_logit) in _get_best_indexes_and_logits(
710
+ result=result,
711
+ n_best_size=n_best_size,
712
+ xlnet_format=xlnet_format):
713
+ # We could hypothetically create invalid predictions, e.g., predict
714
+ # that the start of the span is in the question. We throw out all
715
+ # invalid predictions.
716
+ if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
717
+ continue
718
+ if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
719
+ continue
720
+ if not feature.token_is_max_context.get(start_index, False):
721
+ continue
722
+ if end_index < start_index:
723
+ continue
724
+ length = end_index - start_index + 1
725
+ if length > max_answer_length:
726
+ continue
727
+ prelim_predictions.append(
728
+ _PrelimPrediction(
729
+ feature_index=feature_index,
730
+ start_index=start_index - doc_offset,
731
+ end_index=end_index - doc_offset,
732
+ start_logit=start_logit,
733
+ end_logit=end_logit))
734
+
735
+ if version_2_with_negative and not xlnet_format:
736
+ prelim_predictions.append(
737
+ _PrelimPrediction(
738
+ feature_index=min_null_feature_index,
739
+ start_index=-1,
740
+ end_index=-1,
741
+ start_logit=null_start_logit,
742
+ end_logit=null_end_logit))
743
+ prelim_predictions = sorted(
744
+ prelim_predictions,
745
+ key=lambda x: (x.start_logit + x.end_logit),
746
+ reverse=True)
747
+
748
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
749
+ "NbestPrediction", ["text", "start_logit", "end_logit"])
750
+
751
+ seen_predictions = {}
752
+ nbest = []
753
+ for pred in prelim_predictions:
754
+ if len(nbest) >= n_best_size:
755
+ break
756
+ feature = features[pred.feature_index]
757
+ if pred.start_index >= 0 or xlnet_format: # this is a non-null prediction
758
+ tok_start_to_orig_index = feature.tok_start_to_orig_index
759
+ tok_end_to_orig_index = feature.tok_end_to_orig_index
760
+ start_orig_pos = tok_start_to_orig_index[pred.start_index]
761
+ end_orig_pos = tok_end_to_orig_index[pred.end_index]
762
+
763
+ paragraph_text = example.paragraph_text
764
+ final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
765
+ if final_text in seen_predictions:
766
+ continue
767
+
768
+ seen_predictions[final_text] = True
769
+ else:
770
+ final_text = ""
771
+ seen_predictions[final_text] = True
772
+
773
+ nbest.append(
774
+ _NbestPrediction(
775
+ text=final_text,
776
+ start_logit=pred.start_logit,
777
+ end_logit=pred.end_logit))
778
+
779
+ # if we didn't include the empty option in the n-best, include it
780
+ if version_2_with_negative and not xlnet_format:
781
+ if "" not in seen_predictions:
782
+ nbest.append(
783
+ _NbestPrediction(
784
+ text="", start_logit=null_start_logit,
785
+ end_logit=null_end_logit))
786
+ # In very rare edge cases we could have no valid predictions. So we
787
+ # just create a nonce prediction in this case to avoid failure.
788
+ if not nbest:
789
+ nbest.append(
790
+ _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
791
+
792
+ assert len(nbest) >= 1
793
+
794
+ total_scores = []
795
+ best_non_null_entry = None
796
+ for entry in nbest:
797
+ total_scores.append(entry.start_logit + entry.end_logit)
798
+ if not best_non_null_entry:
799
+ if entry.text:
800
+ best_non_null_entry = entry
801
+
802
+ probs = _compute_softmax(total_scores)
803
+
804
+ nbest_json = []
805
+ for (i, entry) in enumerate(nbest):
806
+ output = collections.OrderedDict()
807
+ output["text"] = entry.text
808
+ output["probability"] = probs[i]
809
+ output["start_logit"] = entry.start_logit
810
+ output["end_logit"] = entry.end_logit
811
+ nbest_json.append(output)
812
+
813
+ assert len(nbest_json) >= 1
814
+
815
+ if not version_2_with_negative:
816
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
817
+ else:
818
+ assert best_non_null_entry is not None
819
+ if xlnet_format:
820
+ score_diff = score_null
821
+ scores_diff_json[example.qas_id] = score_diff
822
+ all_predictions[example.qas_id] = best_non_null_entry.text
823
+ else:
824
+ # predict "" iff the null score - the score of best non-null > threshold
825
+ score_diff = score_null - best_non_null_entry.start_logit - (
826
+ best_non_null_entry.end_logit)
827
+ scores_diff_json[example.qas_id] = score_diff
828
+ if score_diff > null_score_diff_threshold:
829
+ all_predictions[example.qas_id] = ""
830
+ else:
831
+ all_predictions[example.qas_id] = best_non_null_entry.text
832
+
833
+ all_nbest_json[example.qas_id] = nbest_json
834
+
835
+ return all_predictions, all_nbest_json, scores_diff_json
836
+
837
+
838
+ def write_to_json_files(json_records, json_file):
839
+ with tf.io.gfile.GFile(json_file, "w") as writer:
840
+ writer.write(json.dumps(json_records, indent=4) + "\n")
841
+
842
+
843
+ def _get_best_indexes_and_logits(result,
844
+ n_best_size,
845
+ xlnet_format=False):
846
+ """Generates the n-best indexes and logits from a list."""
847
+ if xlnet_format:
848
+ for i in range(n_best_size):
849
+ for j in range(n_best_size):
850
+ j_index = i * n_best_size + j
851
+ yield (result.start_indexes[i], result.start_logits[i],
852
+ result.end_indexes[j_index], result.end_logits[j_index])
853
+ else:
854
+ start_index_and_score = sorted(enumerate(result.start_logits),
855
+ key=lambda x: x[1], reverse=True)
856
+ end_index_and_score = sorted(enumerate(result.end_logits),
857
+ key=lambda x: x[1], reverse=True)
858
+ for i in range(len(start_index_and_score)):
859
+ if i >= n_best_size:
860
+ break
861
+ for j in range(len(end_index_and_score)):
862
+ if j >= n_best_size:
863
+ break
864
+ yield (start_index_and_score[i][0], start_index_and_score[i][1],
865
+ end_index_and_score[j][0], end_index_and_score[j][1])
866
+
867
+
868
+ def _compute_softmax(scores):
869
+ """Compute softmax probability over raw logits."""
870
+ if not scores:
871
+ return []
872
+
873
+ max_score = None
874
+ for score in scores:
875
+ if max_score is None or score > max_score:
876
+ max_score = score
877
+
878
+ exp_scores = []
879
+ total_sum = 0.0
880
+ for score in scores:
881
+ x = math.exp(score - max_score)
882
+ exp_scores.append(x)
883
+ total_sum += x
884
+
885
+ probs = []
886
+ for score in exp_scores:
887
+ probs.append(score / total_sum)
888
+ return probs
889
+
890
+
891
+ class FeatureWriter(object):
892
+ """Writes InputFeature to TF example file."""
893
+
894
+ def __init__(self, filename, is_training):
895
+ self.filename = filename
896
+ self.is_training = is_training
897
+ self.num_features = 0
898
+ tf.io.gfile.makedirs(os.path.dirname(filename))
899
+ self._writer = tf.io.TFRecordWriter(filename)
900
+
901
+ def process_feature(self, feature):
902
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
903
+ self.num_features += 1
904
+
905
+ def create_int_feature(values):
906
+ feature = tf.train.Feature(
907
+ int64_list=tf.train.Int64List(value=list(values)))
908
+ return feature
909
+
910
+ features = collections.OrderedDict()
911
+ features["unique_ids"] = create_int_feature([feature.unique_id])
912
+ features["input_ids"] = create_int_feature(feature.input_ids)
913
+ features["input_mask"] = create_int_feature(feature.input_mask)
914
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
915
+ if feature.paragraph_mask is not None:
916
+ features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
917
+ if feature.class_index is not None:
918
+ features["class_index"] = create_int_feature([feature.class_index])
919
+
920
+ if self.is_training:
921
+ features["start_positions"] = create_int_feature([feature.start_position])
922
+ features["end_positions"] = create_int_feature([feature.end_position])
923
+ impossible = 0
924
+ if feature.is_impossible:
925
+ impossible = 1
926
+ features["is_impossible"] = create_int_feature([impossible])
927
+
928
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
929
+ self._writer.write(tf_example.SerializeToString())
930
+
931
+ def close(self):
932
+ self._writer.close()
933
+
934
+
935
+ def generate_tf_record_from_json_file(input_file_path,
936
+ sp_model_file,
937
+ output_path,
938
+ translated_input_folder=None,
939
+ max_seq_length=384,
940
+ do_lower_case=True,
941
+ max_query_length=64,
942
+ doc_stride=128,
943
+ xlnet_format=False,
944
+ version_2_with_negative=False):
945
+ """Generates and saves training data into a tf record file."""
946
+ train_examples = read_squad_examples(
947
+ input_file=input_file_path,
948
+ is_training=True,
949
+ version_2_with_negative=version_2_with_negative,
950
+ translated_input_folder=translated_input_folder)
951
+ tokenizer = tokenization.FullSentencePieceTokenizer(
952
+ sp_model_file=sp_model_file)
953
+ train_writer = FeatureWriter(
954
+ filename=output_path, is_training=True)
955
+ number_of_examples = convert_examples_to_features(
956
+ examples=train_examples,
957
+ tokenizer=tokenizer,
958
+ max_seq_length=max_seq_length,
959
+ doc_stride=doc_stride,
960
+ max_query_length=max_query_length,
961
+ is_training=True,
962
+ output_fn=train_writer.process_feature,
963
+ xlnet_format=xlnet_format,
964
+ do_lower_case=do_lower_case)
965
+ train_writer.close()
966
+
967
+ meta_data = {
968
+ "task_type": "bert_squad",
969
+ "train_data_size": number_of_examples,
970
+ "max_seq_length": max_seq_length,
971
+ "max_query_length": max_query_length,
972
+ "doc_stride": doc_stride,
973
+ "version_2_with_negative": version_2_with_negative,
974
+ }
975
+
976
+ return meta_data
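
A minimal usage sketch of the SQuAD preprocessing entry point above. The import path (`official.nlp.data.squad_lib_sp`) and every file path below are assumptions for illustration only; they are not part of this upload.

```python
# Hypothetical usage sketch: build SQuAD v2 training TFRecords with the
# SentencePiece-based pipeline above. All paths and the module name are
# placeholders (assumptions), not values taken from this commit.
from official.nlp.data import squad_lib_sp  # assumed import path for this file

meta_data = squad_lib_sp.generate_tf_record_from_json_file(
    input_file_path="/tmp/squad/train-v2.0.json",   # hypothetical SQuAD v2 JSON
    sp_model_file="/tmp/albert/30k-clean.model",    # hypothetical SentencePiece model
    output_path="/tmp/squad/train.tf_record",
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    version_2_with_negative=True)
print(meta_data["train_data_size"])  # number of features written to the TFRecord
```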
tagging_data_lib.py ADDED
@@ -0,0 +1,426 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Library to process data for tagging task such as NER/POS."""
16
+ import collections
17
+ import os
18
+
19
+ from absl import logging
20
+ import tensorflow as tf, tf_keras
21
+
22
+ from official.nlp.data import classifier_data_lib
23
+ from official.nlp.tools import tokenization
24
+
25
+ # A negative label id for the padding label, which will not contribute
26
+ # to loss/metrics in training.
27
+ _PADDING_LABEL_ID = -1
28
+
29
+ # The special unknown token, used to substitute a word which has too many
30
+ # subwords after tokenization.
31
+ _UNK_TOKEN = "[UNK]"
32
+
33
+
34
+ class InputExample(object):
35
+ """A single training/test example for token classification."""
36
+
37
+ def __init__(self,
38
+ sentence_id,
39
+ sub_sentence_id=0,
40
+ words=None,
41
+ label_ids=None):
42
+ """Constructs an InputExample."""
43
+ self.sentence_id = sentence_id
44
+ self.sub_sentence_id = sub_sentence_id
45
+ self.words = words if words else []
46
+ self.label_ids = label_ids if label_ids else []
47
+
48
+ def add_word_and_label_id(self, word, label_id):
49
+ """Adds word and label_id pair in the example."""
50
+ self.words.append(word)
51
+ self.label_ids.append(label_id)
52
+
53
+
54
+ def _read_one_file(file_name, label_list):
55
+ """Reads one file and returns a list of `InputExample` instances."""
56
+ lines = tf.io.gfile.GFile(file_name, "r").readlines()
57
+ examples = []
58
+ label_id_map = {label: i for i, label in enumerate(label_list)}
59
+ sentence_id = 0
60
+ example = InputExample(sentence_id=0)
61
+ for line in lines:
62
+ line = line.strip("\n")
63
+ if line:
64
+ # The format is: <token>\t<label> for train/dev set and <token> for test.
65
+ items = line.split("\t")
66
+ assert len(items) == 2 or len(items) == 1
67
+ token = items[0].strip()
68
+
69
+ # Assign a dummy label_id for test set
70
+ label_id = label_id_map[items[1].strip()] if len(items) == 2 else 0
71
+ example.add_word_and_label_id(token, label_id)
72
+ else:
73
+ # Empty line indicates a new sentence.
74
+ if example.words:
75
+ examples.append(example)
76
+ sentence_id += 1
77
+ example = InputExample(sentence_id=sentence_id)
78
+
79
+ if example.words:
80
+ examples.append(example)
81
+ return examples
82
+
83
+
84
+ class PanxProcessor(classifier_data_lib.DataProcessor):
85
+ """Processor for the Panx data set."""
86
+ supported_languages = [
87
+ "ar", "he", "vi", "id", "jv", "ms", "tl", "eu", "ml", "ta", "te", "af",
88
+ "nl", "en", "de", "el", "bn", "hi", "mr", "ur", "fa", "fr", "it", "pt",
89
+ "es", "bg", "ru", "ja", "ka", "ko", "th", "sw", "yo", "my", "zh", "kk",
90
+ "tr", "et", "fi", "hu"
91
+ ]
92
+
93
+ def __init__(self,
94
+ process_text_fn=tokenization.convert_to_unicode,
95
+ only_use_en_train=True,
96
+ only_use_en_dev=True):
97
+ """See base class.
98
+
99
+ Args:
100
+ process_text_fn: See base class.
101
+ only_use_en_train: If True, only use english training data. Otherwise, use
102
+ training data from all languages.
103
+ only_use_en_dev: If True, only use english dev data. Otherwise, use dev
104
+ data from all languages.
105
+ """
106
+ super(PanxProcessor, self).__init__(process_text_fn)
107
+ self.only_use_en_train = only_use_en_train
108
+ self.only_use_en_dev = only_use_en_dev
109
+
110
+ def get_train_examples(self, data_dir):
111
+ examples = _read_one_file(
112
+ os.path.join(data_dir, "train-en.tsv"), self.get_labels())
113
+ if not self.only_use_en_train:
114
+ for language in self.supported_languages:
115
+ if language == "en":
116
+ continue
117
+ examples.extend(
118
+ _read_one_file(
119
+ os.path.join(data_dir, f"train-{language}.tsv"),
120
+ self.get_labels()))
121
+ return examples
122
+
123
+ def get_dev_examples(self, data_dir):
124
+ examples = _read_one_file(
125
+ os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
126
+ if not self.only_use_en_dev:
127
+ for language in self.supported_languages:
128
+ if language == "en":
129
+ continue
130
+ examples.extend(
131
+ _read_one_file(
132
+ os.path.join(data_dir, f"dev-{language}.tsv"),
133
+ self.get_labels()))
134
+ return examples
135
+
136
+ def get_test_examples(self, data_dir):
137
+ examples_dict = {}
138
+ for language in self.supported_languages:
139
+ examples_dict[language] = _read_one_file(
140
+ os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
141
+ return examples_dict
142
+
143
+ def get_labels(self):
144
+ return ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]
145
+
146
+ @staticmethod
147
+ def get_processor_name():
148
+ return "panx"
149
+
150
+
151
+ class UdposProcessor(classifier_data_lib.DataProcessor):
152
+ """Processor for the Udpos data set."""
153
+ supported_languages = [
154
+ "af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fr",
155
+ "he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl", "pt", "ru",
156
+ "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
157
+ ]
158
+
159
+ def __init__(self,
160
+ process_text_fn=tokenization.convert_to_unicode,
161
+ only_use_en_train=True,
162
+ only_use_en_dev=True):
163
+ """See base class.
164
+
165
+ Args:
166
+ process_text_fn: See base class.
167
+ only_use_en_train: If True, only use english training data. Otherwise, use
168
+ training data from all languages.
169
+ only_use_en_dev: If True, only use english dev data. Otherwise, use dev
170
+ data from all languages.
171
+ """
172
+ super(UdposProcessor, self).__init__(process_text_fn)
173
+ self.only_use_en_train = only_use_en_train
174
+ self.only_use_en_dev = only_use_en_dev
175
+
176
+ def get_train_examples(self, data_dir):
177
+ if self.only_use_en_train:
178
+ examples = _read_one_file(
179
+ os.path.join(data_dir, "train-en.tsv"), self.get_labels())
180
+ else:
181
+ examples = []
182
+ # Uses glob because some languages are missing in train.
183
+ for filepath in tf.io.gfile.glob(os.path.join(data_dir, "train-*.tsv")):
184
+ examples.extend(
185
+ _read_one_file(
186
+ filepath,
187
+ self.get_labels()))
188
+ return examples
189
+
190
+ def get_dev_examples(self, data_dir):
191
+ if self.only_use_en_dev:
192
+ examples = _read_one_file(
193
+ os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
194
+ else:
195
+ examples = []
196
+ for filepath in tf.io.gfile.glob(os.path.join(data_dir, "dev-*.tsv")):
197
+ examples.extend(
198
+ _read_one_file(
199
+ filepath,
200
+ self.get_labels()))
201
+ return examples
202
+
203
+ def get_test_examples(self, data_dir):
204
+ examples_dict = {}
205
+ for language in self.supported_languages:
206
+ examples_dict[language] = _read_one_file(
207
+ os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
208
+ return examples_dict
209
+
210
+ def get_labels(self):
211
+ return [
212
+ "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
213
+ "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"
214
+ ]
215
+
216
+ @staticmethod
217
+ def get_processor_name():
218
+ return "udpos"
219
+
220
+
221
+ def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
222
+ """Tokenizes words and breaks long example into short ones."""
223
+ # Needs additional [CLS] and [SEP] tokens.
224
+ max_length = max_length - 2
225
+ new_examples = []
226
+ new_example = InputExample(sentence_id=example.sentence_id, sub_sentence_id=0)
227
+ if any([x < 0 for x in example.label_ids]):
228
+ raise ValueError("Unexpected negative label_id: %s" % example.label_ids)
229
+
230
+ for i, word in enumerate(example.words):
231
+ if text_preprocessing:
232
+ word = text_preprocessing(word)
233
+ subwords = tokenizer.tokenize(word)
234
+ if (not subwords or len(subwords) > max_length) and word:
235
+ subwords = [_UNK_TOKEN]
236
+
237
+ if len(subwords) + len(new_example.words) > max_length:
238
+ # Start a new example.
239
+ new_examples.append(new_example)
240
+ last_sub_sentence_id = new_example.sub_sentence_id
241
+ new_example = InputExample(
242
+ sentence_id=example.sentence_id,
243
+ sub_sentence_id=last_sub_sentence_id + 1)
244
+
245
+ for j, subword in enumerate(subwords):
246
+ # Use the real label for the first subword, and pad label for
247
+ # the remainings.
248
+ subword_label = example.label_ids[i] if j == 0 else _PADDING_LABEL_ID
249
+ new_example.add_word_and_label_id(subword, subword_label)
250
+
251
+ if new_example.words:
252
+ new_examples.append(new_example)
253
+
254
+ return new_examples
255
+
256
+
257
+ def _convert_single_example(example, max_seq_length, tokenizer):
258
+ """Converts an `InputExample` instance to a `tf.train.Example` instance."""
259
+ tokens = ["[CLS]"]
260
+ tokens.extend(example.words)
261
+ tokens.append("[SEP]")
262
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
263
+ label_ids = [_PADDING_LABEL_ID]
264
+ label_ids.extend(example.label_ids)
265
+ label_ids.append(_PADDING_LABEL_ID)
266
+
267
+ segment_ids = [0] * len(input_ids)
268
+ input_mask = [1] * len(input_ids)
269
+
270
+ # Pad up to the sequence length.
271
+ while len(input_ids) < max_seq_length:
272
+ input_ids.append(0)
273
+ input_mask.append(0)
274
+ segment_ids.append(0)
275
+ label_ids.append(_PADDING_LABEL_ID)
276
+
277
+ def create_int_feature(values):
278
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
279
+
280
+ features = collections.OrderedDict()
281
+ features["input_ids"] = create_int_feature(input_ids)
282
+ features["input_mask"] = create_int_feature(input_mask)
283
+ features["segment_ids"] = create_int_feature(segment_ids)
284
+ features["label_ids"] = create_int_feature(label_ids)
285
+ features["sentence_id"] = create_int_feature([example.sentence_id])
286
+ features["sub_sentence_id"] = create_int_feature([example.sub_sentence_id])
287
+
288
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
289
+ return tf_example
290
+
291
+
292
+ def write_example_to_file(examples,
293
+ tokenizer,
294
+ max_seq_length,
295
+ output_file,
296
+ text_preprocessing=None):
297
+ """Writes `InputExample`s into a tfrecord file with `tf.train.Example` protos.
298
+
299
+ Note that the words inside each example will be tokenized and be applied by
300
+ `text_preprocessing` if available. Also, if the length of sentence (plus
301
+ special [CLS] and [SEP] tokens) exceeds `max_seq_length`, the long sentence
302
+ will be broken into multiple short examples. For example:
303
+
304
+ Example (text_preprocessing=lowercase, max_seq_length=5)
305
+ words: ["What", "a", "great", "weekend"]
306
+ labels: [ 7, 5, 9, 10]
307
+ sentence_id: 0
308
+ preprocessed: ["what", "a", "great", "weekend"]
309
+ tokenized: ["what", "a", "great", "week", "##end"]
310
+
311
+ will result in two tf.example protos:
312
+
313
+ tokens: ["[CLS]", "what", "a", "great", "[SEP]"]
314
+ label_ids: [-1, 7, 5, 9, -1]
315
+ input_mask: [ 1, 1, 1, 1, 1]
316
+ segment_ids: [ 0, 0, 0, 0, 0]
317
+ input_ids: [ tokenizer.convert_tokens_to_ids(tokens) ]
318
+ sentence_id: 0
319
+
320
+ tokens: ["[CLS]", "week", "##end", "[SEP]", "[PAD]"]
321
+ label_ids: [-1, 10, -1, -1, -1]
322
+ input_mask: [ 1, 1, 1, 0, 0]
323
+ segment_ids: [ 0, 0, 0, 0, 0]
324
+ input_ids: [ tokenizer.convert_tokens_to_ids(tokens) ]
325
+ sentence_id: 0
326
+
327
+ Note the use of -1 in `label_ids` to indicate that a token should not be
328
+ considered for classification (e.g., trailing ## wordpieces or special
329
+ token). Token classification models should accordingly ignore these when
330
+ calculating loss, metrics, etc...
331
+
332
+ Args:
333
+ examples: A list of `InputExample` instances.
334
+ tokenizer: The tokenizer to be applied on the data.
335
+ max_seq_length: Maximum length of generated sequences.
336
+ output_file: The name of the output tfrecord file.
337
+ text_preprocessing: optional preprocessing run on each word prior to
338
+ tokenization.
339
+
340
+ Returns:
341
+ The total number of tf.train.Example proto written to file.
342
+ """
343
+ tf.io.gfile.makedirs(os.path.dirname(output_file))
344
+ writer = tf.io.TFRecordWriter(output_file)
345
+ num_tokenized_examples = 0
346
+ for (ex_index, example) in enumerate(examples):
347
+ if ex_index % 10000 == 0:
348
+ logging.info("Writing example %d of %d to %s", ex_index, len(examples),
349
+ output_file)
350
+
351
+ tokenized_examples = _tokenize_example(example, max_seq_length, tokenizer,
352
+ text_preprocessing)
353
+ num_tokenized_examples += len(tokenized_examples)
354
+ for per_tokenized_example in tokenized_examples:
355
+ tf_example = _convert_single_example(per_tokenized_example,
356
+ max_seq_length, tokenizer)
357
+ writer.write(tf_example.SerializeToString())
358
+
359
+ writer.close()
360
+ return num_tokenized_examples
361
+
362
+
363
+ def token_classification_meta_data(train_data_size,
364
+ max_seq_length,
365
+ num_labels,
366
+ eval_data_size=None,
367
+ test_data_size=None,
368
+ label_list=None,
369
+ processor_type=None):
370
+ """Creates metadata for tagging (token classification) datasets."""
371
+ meta_data = {
372
+ "train_data_size": train_data_size,
373
+ "max_seq_length": max_seq_length,
374
+ "num_labels": num_labels,
375
+ "task_type": "tagging",
376
+ "label_type": "int",
377
+ "label_shape": [max_seq_length],
378
+ }
379
+ if eval_data_size:
380
+ meta_data["eval_data_size"] = eval_data_size
381
+ if test_data_size:
382
+ meta_data["test_data_size"] = test_data_size
383
+ if label_list:
384
+ meta_data["label_list"] = label_list
385
+ if processor_type:
386
+ meta_data["processor_type"] = processor_type
387
+
388
+ return meta_data
389
+
390
+
391
+ def generate_tf_record_from_data_file(processor, data_dir, tokenizer,
392
+ max_seq_length, train_data_output_path,
393
+ eval_data_output_path,
394
+ test_data_output_path,
395
+ text_preprocessing):
396
+ """Generates tfrecord files from the raw data."""
397
+ common_kwargs = dict(
398
+ tokenizer=tokenizer,
399
+ max_seq_length=max_seq_length,
400
+ text_preprocessing=text_preprocessing)
401
+ train_examples = processor.get_train_examples(data_dir)
402
+ train_data_size = write_example_to_file(
403
+ train_examples, output_file=train_data_output_path, **common_kwargs)
404
+
405
+ eval_examples = processor.get_dev_examples(data_dir)
406
+ eval_data_size = write_example_to_file(
407
+ eval_examples, output_file=eval_data_output_path, **common_kwargs)
408
+
409
+ test_input_data_examples = processor.get_test_examples(data_dir)
410
+ test_data_size = {}
411
+ for language, examples in test_input_data_examples.items():
412
+ test_data_size[language] = write_example_to_file(
413
+ examples,
414
+ output_file=test_data_output_path.format(language),
415
+ **common_kwargs)
416
+
417
+ labels = processor.get_labels()
418
+ meta_data = token_classification_meta_data(
419
+ train_data_size,
420
+ max_seq_length,
421
+ len(labels),
422
+ eval_data_size,
423
+ test_data_size,
424
+ label_list=labels,
425
+ processor_type=processor.get_processor_name())
426
+ return meta_data
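
A short sketch of how the processors, tokenizer, and writer above fit together for the PANX task. The data directory, vocab file, and output paths are hypothetical placeholders.

```python
# Hypothetical usage sketch -- all paths below are placeholders.
from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization

processor = tagging_data_lib.PanxProcessor(only_use_en_train=True)
tokenizer = tokenization.FullTokenizer(
    vocab_file="/tmp/bert/vocab.txt", do_lower_case=True)
meta_data = tagging_data_lib.generate_tf_record_from_data_file(
    processor,
    data_dir="/tmp/panx",
    tokenizer=tokenizer,
    max_seq_length=128,
    train_data_output_path="/tmp/panx/train.tfrecord",
    eval_data_output_path="/tmp/panx/eval.tfrecord",
    test_data_output_path="/tmp/panx/test_{}.tfrecord",  # formatted per language
    text_preprocessing=tokenization.convert_to_unicode)
# meta_data["num_labels"] is 7, matching the PANX NER label set above.
```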
tagging_data_lib_test.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.tagging_data_lib."""
16
+ import os
17
+ import random
18
+
19
+ from absl.testing import parameterized
20
+ import tensorflow as tf, tf_keras
21
+
22
+ from official.nlp.data import tagging_data_lib
23
+ from official.nlp.tools import tokenization
24
+
25
+
26
+ def _create_fake_file(filename, labels, is_test):
27
+
28
+ def write_one_sentence(writer, length):
29
+ for _ in range(length):
30
+ line = "hiworld"
31
+ if not is_test:
32
+ line += "\t%s" % (labels[random.randint(0, len(labels) - 1)])
33
+ writer.write(line + "\n")
34
+
35
+ # Writes two sentences with lengths of 3 and 12 words, respectively.
36
+ with tf.io.gfile.GFile(filename, "w") as writer:
37
+ write_one_sentence(writer, 3)
38
+ writer.write("\n")
39
+ write_one_sentence(writer, 12)
40
+
41
+
42
+ class TaggingDataLibTest(tf.test.TestCase, parameterized.TestCase):
43
+
44
+ def setUp(self):
45
+ super(TaggingDataLibTest, self).setUp()
46
+
47
+ self.processors = {
48
+ "panx": tagging_data_lib.PanxProcessor,
49
+ "udpos": tagging_data_lib.UdposProcessor,
50
+ }
51
+ self.vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")
52
+ with tf.io.gfile.GFile(self.vocab_file, "w") as writer:
53
+ writer.write("\n".join(["[CLS]", "[SEP]", "hi", "##world", "[UNK]"]))
54
+
55
+ @parameterized.parameters(
56
+ {"task_type": "panx"},
57
+ {"task_type": "udpos"},
58
+ )
59
+ def test_generate_tf_record(self, task_type):
60
+ processor = self.processors[task_type]()
61
+ input_data_dir = os.path.join(self.get_temp_dir(), task_type)
62
+ tf.io.gfile.mkdir(input_data_dir)
63
+ # Write fake train file.
64
+ _create_fake_file(
65
+ os.path.join(input_data_dir, "train-en.tsv"),
66
+ processor.get_labels(),
67
+ is_test=False)
68
+
69
+ # Write fake dev file.
70
+ _create_fake_file(
71
+ os.path.join(input_data_dir, "dev-en.tsv"),
72
+ processor.get_labels(),
73
+ is_test=False)
74
+
75
+ # Write fake test files.
76
+ for lang in processor.supported_languages:
77
+ _create_fake_file(
78
+ os.path.join(input_data_dir, "test-%s.tsv" % lang),
79
+ processor.get_labels(),
80
+ is_test=True)
81
+
82
+ output_path = os.path.join(self.get_temp_dir(), task_type, "output")
83
+ tokenizer = tokenization.FullTokenizer(
84
+ vocab_file=self.vocab_file, do_lower_case=True)
85
+ metadata = tagging_data_lib.generate_tf_record_from_data_file(
86
+ processor,
87
+ input_data_dir,
88
+ tokenizer,
89
+ max_seq_length=8,
90
+ train_data_output_path=os.path.join(output_path, "train.tfrecord"),
91
+ eval_data_output_path=os.path.join(output_path, "eval.tfrecord"),
92
+ test_data_output_path=os.path.join(output_path, "test_{}.tfrecord"),
93
+ text_preprocessing=tokenization.convert_to_unicode)
94
+
95
+ self.assertEqual(metadata["train_data_size"], 5)
96
+ files = tf.io.gfile.glob(output_path + "/*")
97
+ expected_files = []
98
+ expected_files.append(os.path.join(output_path, "train.tfrecord"))
99
+ expected_files.append(os.path.join(output_path, "eval.tfrecord"))
100
+ for lang in processor.supported_languages:
101
+ expected_files.append(
102
+ os.path.join(output_path, "test_%s.tfrecord" % lang))
103
+
104
+ self.assertCountEqual(files, expected_files)
105
+
106
+
107
+ if __name__ == "__main__":
108
+ tf.test.main()
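
A note on the expected `train_data_size` of 5 in the test above: each fake word "hiworld" tokenizes into two subwords ("hi", "##world"), and with `max_seq_length=8` only three words fit per sub-example once [CLS] and [SEP] are reserved, so the 3-word sentence yields one example and the 12-word sentence is split into four.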
tagging_dataloader.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Loads dataset for the tagging (e.g., NER/POS) task."""
16
+ import dataclasses
17
+ from typing import Mapping, Optional
18
+
19
+ import tensorflow as tf, tf_keras
20
+ from official.common import dataset_fn
21
+ from official.core import config_definitions as cfg
22
+ from official.core import input_reader
23
+ from official.nlp.data import data_loader
24
+ from official.nlp.data import data_loader_factory
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class TaggingDataConfig(cfg.DataConfig):
29
+ """Data config for tagging (tasks/tagging)."""
30
+ is_training: bool = True
31
+ seq_length: int = 128
32
+ include_sentence_id: bool = False
33
+ file_type: str = 'tfrecord'
34
+
35
+
36
+ @data_loader_factory.register_data_loader_cls(TaggingDataConfig)
37
+ class TaggingDataLoader(data_loader.DataLoader):
38
+ """A class to load dataset for tagging (e.g., NER and POS) task."""
39
+
40
+ def __init__(self, params: TaggingDataConfig):
41
+ self._params = params
42
+ self._seq_length = params.seq_length
43
+ self._include_sentence_id = params.include_sentence_id
44
+
45
+ def _decode(self, record: tf.Tensor):
46
+ """Decodes a serialized tf.Example."""
47
+ name_to_features = {
48
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
49
+ 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
50
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
51
+ 'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
52
+ }
53
+ if self._include_sentence_id:
54
+ name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
55
+ name_to_features['sub_sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
56
+
57
+ example = tf.io.parse_single_example(record, name_to_features)
58
+
59
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
60
+ # So cast all int64 to int32.
61
+ for name in example:
62
+ t = example[name]
63
+ if t.dtype == tf.int64:
64
+ t = tf.cast(t, tf.int32)
65
+ example[name] = t
66
+
67
+ return example
68
+
69
+ def _parse(self, record: Mapping[str, tf.Tensor]):
70
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
71
+ x = {
72
+ 'input_word_ids': record['input_ids'],
73
+ 'input_mask': record['input_mask'],
74
+ 'input_type_ids': record['segment_ids']
75
+ }
76
+ if self._include_sentence_id:
77
+ x['sentence_id'] = record['sentence_id']
78
+ x['sub_sentence_id'] = record['sub_sentence_id']
79
+
80
+ y = record['label_ids']
81
+ return (x, y)
82
+
83
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
84
+ """Returns a tf.dataset.Dataset."""
85
+ reader = input_reader.InputReader(
86
+ params=self._params,
87
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
88
+ decoder_fn=self._decode,
89
+ parser_fn=self._parse)
90
+ return reader.read(input_context)
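
A minimal sketch of loading a tagging TFRecord with the config and loader defined above; the input path and batch size are placeholders.

```python
# Hypothetical usage sketch -- the tfrecord path is a placeholder.
from official.nlp.data import tagging_dataloader

data_config = tagging_dataloader.TaggingDataConfig(
    input_path='/tmp/panx/train.tfrecord',
    seq_length=128,
    global_batch_size=32,
    is_training=True,
    include_sentence_id=False)
dataset = tagging_dataloader.TaggingDataLoader(data_config).load()
features, labels = next(iter(dataset))
# features holds input_word_ids / input_mask / input_type_ids, each [32, 128];
# labels is label_ids of shape [32, 128], with -1 marking padded positions.
```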
tagging_dataloader_test.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.tagging_data_loader."""
16
+ import os
17
+
18
+ from absl.testing import parameterized
19
+ import numpy as np
20
+ import tensorflow as tf, tf_keras
21
+
22
+ from official.nlp.data import tagging_dataloader
23
+
24
+
25
+ def _create_fake_dataset(output_path, seq_length, include_sentence_id):
26
+ """Creates a fake dataset."""
27
+ writer = tf.io.TFRecordWriter(output_path)
28
+
29
+ def create_int_feature(values):
30
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
31
+ return f
32
+
33
+ for i in range(100):
34
+ features = {}
35
+ input_ids = np.random.randint(100, size=(seq_length))
36
+ features['input_ids'] = create_int_feature(input_ids)
37
+ features['input_mask'] = create_int_feature(np.ones_like(input_ids))
38
+ features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
39
+ features['label_ids'] = create_int_feature(
40
+ np.random.randint(10, size=(seq_length)))
41
+ if include_sentence_id:
42
+ features['sentence_id'] = create_int_feature([i])
43
+ features['sub_sentence_id'] = create_int_feature([0])
44
+
45
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
46
+ writer.write(tf_example.SerializeToString())
47
+ writer.close()
48
+
49
+
50
+ class TaggingDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
51
+
52
+ @parameterized.parameters(True, False)
53
+ def test_load_dataset(self, include_sentence_id):
54
+ seq_length = 16
55
+ batch_size = 10
56
+ train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
57
+ _create_fake_dataset(train_data_path, seq_length, include_sentence_id)
58
+ data_config = tagging_dataloader.TaggingDataConfig(
59
+ input_path=train_data_path,
60
+ seq_length=seq_length,
61
+ global_batch_size=batch_size,
62
+ include_sentence_id=include_sentence_id)
63
+
64
+ dataset = tagging_dataloader.TaggingDataLoader(data_config).load()
65
+ features, labels = next(iter(dataset))
66
+
67
+ expected_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
68
+ if include_sentence_id:
69
+ expected_keys.extend(['sentence_id', 'sub_sentence_id'])
70
+ self.assertCountEqual(expected_keys, features.keys())
71
+
72
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
73
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
74
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
75
+ self.assertEqual(labels.shape, (batch_size, seq_length))
76
+ if include_sentence_id:
77
+ self.assertEqual(features['sentence_id'].shape, (batch_size,))
78
+ self.assertEqual(features['sub_sentence_id'].shape, (batch_size,))
79
+
80
+
81
+ if __name__ == '__main__':
82
+ tf.test.main()
train_sentencepiece.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A script to train sentencepiece model from tensorflow datasets.
16
+
17
+ Reserved tokens:
18
+ pad: 0,
19
+ eos: 1,
20
+ unk: 2
21
+ (bos is not reserved)
22
+ """
23
+
24
+ import os
25
+ import tempfile
26
+ from typing import List, Tuple
27
+
28
+ from absl import app
29
+ from absl import flags
30
+ from absl import logging
31
+ import tensorflow as tf, tf_keras
32
+ import tensorflow_datasets as tfds
33
+
34
+ from sentencepiece import SentencePieceTrainer
35
+
36
+
37
+ FLAGS = flags.FLAGS
38
+ flags.DEFINE_string("output_model_path", None,
39
+ "Path to save the sentencepiece model.")
40
+ flags.mark_flag_as_required("output_model_path")
41
+
42
+ flags.DEFINE_string("tfds_dir", None, "Directory of the tfds.")
43
+ flags.DEFINE_string("tfds_name", "wmt14_translate/de-en",
44
+ "Name of the dataset we generate vacabulay from.")
45
+ flags.DEFINE_string("tfds_split", "train", "Split of the dataset.")
46
+ flags.DEFINE_integer("vocab_size", 32000, "Size of vocabulary.")
47
+ flags.DEFINE_integer(
48
+ "max_char", -1,
49
+ "Maximum number of characters to use. "
50
+ "If a non-positive number is provided, all sentences are used.")
51
+ flags.DEFINE_string("model_type", "bpe",
52
+ "Model algorithm: unigram, bpe, word or char.")
53
+ flags.DEFINE_float("character_coverage", 0.9995,
54
+ "Character coverage to determine the minimum symbols")
55
+ flags.DEFINE_list(
56
+ "data_keys", ["en", "de"],
57
+ "Comma-separated list of keys to use for training the vocabulary.")
58
+
59
+
60
+ def dump_chars_to_textfile(dataset: tf.data.Dataset,
61
+ data_keys: Tuple[str],
62
+ max_char: int = -1):
63
+ """Write part of a TFDS sentence dataset to lines in a text file.
64
+
65
+ Args:
66
+ dataset: tf.dataset containing string-data.
67
+ data_keys: what keys in dataset to dump from.
68
+ max_char: max character to dump to text file.
69
+
70
+ Returns:
71
+ Name of the temp file containing the dumped text.
72
+ """
73
+ ds_iter = dataset.as_numpy_iterator()
74
+ with tempfile.NamedTemporaryFile(delete=False) as outfp:
75
+ char_count = 0
76
+ while True:
77
+ example = next(ds_iter, None)
78
+ if example is None or (
79
+ max_char > 0 and char_count > max_char):
80
+ break
81
+ for k in data_keys:
82
+ line = example[k] + b"\n"
83
+ char_count += len(line)
84
+ outfp.write(line)
85
+ return outfp.name
86
+
87
+
88
+ def train_sentencepiece(
89
+ file_path: str,
90
+ model_path: str,
91
+ vocab_size: int,
92
+ character_coverage: float,
93
+ model_type: str):
94
+ """Train SentencePiece tokenizer from subset of tf dataset.
95
+
96
+ Args:
97
+ file_path: path of data to train sentencepiece.
98
+ model_path: path of model file to save vocab model to.
99
+ vocab_size: size of vocab tokens to train.
100
+ character_coverage: amount of characters covered by the model, good defaults
101
+ are 0.9995 for languages with rich character set like Japanese or Chinese
102
+ and 1.0 for other languages with small character set.
103
+ model_type: type of sentencepiece vocab to train.
104
+
105
+ Returns:
106
+ None. The trained model is written to `<model_path>.model` and `<model_path>.vocab`.
107
+ """
108
+ argstr = " ".join([
109
+ f"--input={file_path}", f"--vocab_size={vocab_size}",
110
+ f"--character_coverage={character_coverage}",
111
+ f"--model_prefix={model_path}", f"--model_type={model_type}",
112
+ "--bos_id=-1", "--pad_id=0", "--eos_id=1", "--unk_id=2"
113
+ ])
114
+ SentencePieceTrainer.Train(argstr)
115
+
116
+
117
+ def main(argv: List[str]):
118
+ del argv
119
+ builder = tfds.builder(FLAGS.tfds_name, data_dir=FLAGS.tfds_dir)
120
+ ds = builder.as_dataset(split=FLAGS.tfds_split)
121
+ tmp_filename = dump_chars_to_textfile(ds, FLAGS.data_keys, FLAGS.max_char)
122
+ logging.info("Sentencepiece model will be placed here: %s",
123
+ FLAGS.output_model_path)
124
+ train_sentencepiece(tmp_filename,
125
+ FLAGS.output_model_path,
126
+ FLAGS.vocab_size,
127
+ FLAGS.character_coverage,
128
+ FLAGS.model_type)
129
+ os.remove(tmp_filename)
130
+
131
+
132
+ if __name__ == "__main__":
133
+ app.run(main)
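
The flag-driven `main` above is the intended entry point, but the helper can also be called directly on a plain-text corpus. A minimal sketch, with an assumed import path and placeholder paths:

```python
# Hypothetical usage sketch -- the corpus path and model prefix are placeholders.
from train_sentencepiece import train_sentencepiece  # assumed import path

train_sentencepiece(
    file_path="/tmp/corpus.txt",   # one sentence per line
    model_path="/tmp/spm_wmt",     # SentencePiece writes /tmp/spm_wmt.model and .vocab
    vocab_size=8000,
    character_coverage=1.0,
    model_type="bpe")
```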
wmt_dataloader.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Input pipeline for the transformer model to read, filter, and batch examples.
16
+
17
+ Batching scheme
18
+
19
+ Prior to batching, elements in the dataset are grouped by length (max between
20
+ 'inputs' and 'targets' length). Each group is then batched such that:
21
+ group_batch_size * length <= batch_size.
22
+
23
+ Another way to view batch_size is the maximum number of tokens in each batch.
24
+
25
+ Once batched, each element in the dataset will have the shape:
26
+ {'inputs': [group_batch_size, padded_input_length],
27
+ 'targets': [group_batch_size, padded_target_length]}
28
+ Lengths are padded to the longest 'inputs' or 'targets' sequence in the batch
29
+ (padded_input_length and padded_target_length can be different).
30
+
31
+ This batching scheme decreases the fraction of padding tokens per training
32
+ batch, thus improving the training speed significantly.
33
+ """
34
+ from typing import Dict, Optional
35
+
36
+ import dataclasses
37
+ import tensorflow as tf, tf_keras
38
+ import tensorflow_text as tftxt
39
+ from official.core import config_definitions as cfg
40
+ from official.core import input_reader
41
+ from official.nlp.data import data_loader
42
+ from official.nlp.data import data_loader_factory
43
+
44
+ # Example grouping constants. Defines length boundaries for each group.
45
+ # These values are the defaults used in Tensor2Tensor.
46
+ _MIN_BOUNDARY = 8
47
+ _BOUNDARY_SCALE = 1.1
48
+
49
+
50
+ def _get_example_length(example):
51
+ """Returns the maximum length between the example inputs and targets."""
52
+ length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
53
+ return length
54
+
55
+
56
+ def _create_min_max_boundaries(max_length,
57
+ min_boundary=_MIN_BOUNDARY,
58
+ boundary_scale=_BOUNDARY_SCALE):
59
+ """Create min and max boundary lists up to max_length.
60
+
61
+ For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
62
+ returned values will be:
63
+ buckets_min = [0, 4, 8, 16]
64
+ buckets_max = [4, 8, 16, 25]
65
+
66
+ Args:
67
+ max_length: The maximum length of example in dataset.
68
+ min_boundary: Minimum length in boundary.
69
+ boundary_scale: Amount to scale consecutive boundaries in the list.
70
+
71
+ Returns:
72
+ min and max boundary lists
73
+
74
+ """
75
+ # Create bucket boundaries list by scaling the previous boundary or adding 1
76
+ # (to ensure increasing boundary sizes).
77
+ bucket_boundaries = []
78
+ x = min_boundary
79
+ while x < max_length:
80
+ bucket_boundaries.append(x)
81
+ x = max(x + 1, int(x * boundary_scale))
82
+
83
+ # Create min and max boundary lists from the initial list.
84
+ buckets_min = [0] + bucket_boundaries
85
+ buckets_max = bucket_boundaries + [max_length + 1]
86
+ return buckets_min, buckets_max
87
+
88
+
89
+ def _batch_examples(dataset, batch_size, max_length):
90
+ """Group examples by similar lengths, and return batched dataset.
91
+
92
+ Each batch of similar-length examples are padded to the same length, and may
93
+ have different number of elements in each batch, such that:
94
+ group_batch_size * padded_length <= batch_size.
95
+
96
+ This decreases the number of padding tokens per batch, which improves the
97
+ training speed.
98
+
99
+ Args:
100
+ dataset: Dataset of unbatched examples.
101
+ batch_size: Max number of tokens per batch of examples.
102
+ max_length: Max number of tokens in an example input or target sequence.
103
+
104
+ Returns:
105
+ Dataset of batched examples with similar lengths.
106
+ """
107
+ # Get min and max boundary lists for each example. These are used to calculate
108
+ # the `bucket_id`, which is the index at which:
109
+ # buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
110
+ # Note that using both min and max lists improves the performance.
111
+ buckets_min, buckets_max = _create_min_max_boundaries(max_length)
112
+
113
+ # Create list of batch sizes for each bucket_id, so that
114
+ # bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
115
+ bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]
116
+
117
+ # Validates bucket batch sizes.
118
+ if any([batch_size <= 0 for batch_size in bucket_batch_sizes]):
119
+ raise ValueError(
120
+ 'The token budget, global batch size, is too small to yield a non-empty '
121
+ 'bucket window: %s' % str(bucket_batch_sizes))
122
+
123
+ # bucket_id will be a tensor, so convert this list to a tensor as well.
124
+ bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
125
+
126
+ def example_to_bucket_id(example):
127
+ """Return int64 bucket id for this example, calculated based on length."""
128
+ example_input = example['inputs']
129
+ example_target = example['targets']
130
+ seq_length = _get_example_length((example_input, example_target))
131
+
132
+ conditions_c = tf.logical_and(
133
+ tf.less_equal(buckets_min, seq_length), tf.less(seq_length,
134
+ buckets_max))
135
+ bucket_id = tf.reduce_min(tf.where(conditions_c))
136
+ return bucket_id
137
+
138
+ def window_size_fn(bucket_id):
139
+ """Return number of examples to be grouped when given a bucket id."""
140
+ return bucket_batch_sizes[bucket_id]
141
+
142
+ def batching_fn(bucket_id, grouped_dataset):
143
+ """Batch and add padding to a dataset of elements with similar lengths."""
144
+ bucket_batch_size = window_size_fn(bucket_id)
145
+
146
+ # Batch the dataset and add padding so that all input sequences in the
147
+ # examples have the same length, and all target sequences have the same
148
+ # lengths as well. Resulting lengths of inputs and targets can differ.
149
+ padded_shapes = dict([
150
+ (name, [None] * len(spec.shape))
151
+ for name, spec in grouped_dataset.element_spec.items()
152
+ ])
153
+ return grouped_dataset.padded_batch(bucket_batch_size, padded_shapes)
154
+
155
+ return dataset.apply(
156
+ tf.data.experimental.group_by_window(
157
+ key_func=example_to_bucket_id,
158
+ reduce_func=batching_fn,
159
+ window_size=None,
160
+ window_size_func=window_size_fn))
161
+
162
+
163
+ @dataclasses.dataclass
164
+ class WMTDataConfig(cfg.DataConfig):
165
+ """Data config for WMT translation."""
166
+ max_seq_length: int = 64
167
+ static_batch: bool = False
168
+ sentencepiece_model_path: str = ''
169
+ src_lang: str = ''
170
+ tgt_lang: str = ''
171
+ transform_and_batch: bool = True
172
+ has_unique_id: bool = False
173
+
174
+
175
+ @data_loader_factory.register_data_loader_cls(WMTDataConfig)
176
+ class WMTDataLoader(data_loader.DataLoader):
177
+ """A class to load dataset for WMT translation task."""
178
+
179
+ def __init__(self, params: WMTDataConfig):
180
+ self._params = params
181
+ self._max_seq_length = params.max_seq_length
182
+ self._static_batch = params.static_batch
183
+ self._global_batch_size = params.global_batch_size
184
+ if self._params.transform_and_batch:
185
+ self._tokenizer = tftxt.SentencepieceTokenizer(
186
+ model=tf.io.gfile.GFile(params.sentencepiece_model_path, 'rb').read(),
187
+ add_eos=True)
188
+
189
+ def _decode(self, record: tf.Tensor):
190
+ """Decodes a serialized tf.Example."""
191
+ name_to_features = {
192
+ self._params.src_lang: tf.io.FixedLenFeature([], tf.string),
193
+ self._params.tgt_lang: tf.io.FixedLenFeature([], tf.string),
194
+ }
195
+ if self._params.has_unique_id:
196
+ name_to_features['unique_id'] = tf.io.FixedLenFeature([], tf.int64)
197
+ example = tf.io.parse_single_example(record, name_to_features)
198
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
199
+ # So cast all int64 to int32.
200
+ for name in example:
201
+ t = example[name]
202
+ if t.dtype == tf.int64:
203
+ t = tf.cast(t, tf.int32)
204
+ example[name] = t
205
+ return example
206
+
207
+ def _tokenize(self, inputs) -> Dict[str, tf.Tensor]:
208
+ tokenized_inputs = {}
209
+ for k, v in inputs.items():
210
+ if k == self._params.src_lang:
211
+ tokenized_inputs['inputs'] = self._tokenizer.tokenize(v)
212
+ elif k == self._params.tgt_lang:
213
+ tokenized_inputs['targets'] = self._tokenizer.tokenize(v)
214
+ else:
215
+ tokenized_inputs[k] = v
216
217
+ return tokenized_inputs
218
+
219
+ def _filter_max_length(self, inputs):
220
221
+ return tf.logical_and(
222
+ tf.shape(inputs['inputs'])[0] <= self._max_seq_length,
223
+ tf.shape(inputs['targets'])[0] <= self._max_seq_length)
224
+
225
+ def _maybe_truncate(self, inputs):
226
+ truncated_inputs = {}
227
+ for k, v in inputs.items():
228
+ if k == 'inputs' or k == 'targets':
229
+ truncated_inputs[k] = tf.pad(
230
+ v[:self._max_seq_length - 1], [[0, 1]],
231
+ constant_values=1) if tf.shape(v)[0] > self._max_seq_length else v
232
+ else:
233
+ truncated_inputs[k] = v
234
+ return truncated_inputs
235
+
236
+ def _tokenize_bucketize_and_batch(
237
+ self,
238
+ dataset,
239
+ input_context: Optional[tf.distribute.InputContext] = None):
240
+ dataset = dataset.map(
241
+ self._tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
242
+
243
+ if self._params.is_training:
244
+ dataset = dataset.filter(self._filter_max_length)
245
+ else:
246
+ dataset = dataset.map(
247
+ self._maybe_truncate,
248
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
249
+
250
+ per_replica_batch_size = input_context.get_per_replica_batch_size(
251
+ self._global_batch_size) if input_context else self._global_batch_size
252
+ if self._static_batch:
253
+ padded_shapes = {}
254
+ for name, _ in dataset.element_spec.items():
255
+ if name == 'unique_id':
256
+ padded_shapes[name] = []
257
+ else:
258
+ padded_shapes[name] = [self._max_seq_length
259
+ ] if self._static_batch else [None]
260
+ batch_size = per_replica_batch_size
261
+ if self._params.is_training:
262
+ batch_size = int(batch_size // self._max_seq_length)
263
+ dataset = dataset.padded_batch(
264
+ batch_size,
265
+ padded_shapes,
266
+ drop_remainder=True)
267
+ else:
268
+ # Group and batch such that each batch has examples of similar length.
269
+ dataset = _batch_examples(dataset, per_replica_batch_size,
270
+ self._max_seq_length)
271
+ # Prefetch the next element to improve speed of input pipeline.
272
+ dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
273
+ return dataset
274
+
275
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
276
+ """Returns a tf.dataset.Dataset."""
277
+ decoder_fn = None
278
+ # Only decode for TFRecords.
279
+ if self._params.input_path:
280
+ decoder_fn = self._decode
281
+
282
+ def _identity(
283
+ dataset, input_context: Optional[tf.distribute.InputContext] = None):
284
+ del input_context
285
+ return dataset
286
+
287
+ transform_and_batch_fn = _identity
288
+ if self._params.transform_and_batch:
289
+ transform_and_batch_fn = self._tokenize_bucketize_and_batch
290
+
291
+ reader = input_reader.InputReader(
292
+ params=self._params,
293
+ decoder_fn=decoder_fn,
294
+ transform_and_batch_fn=transform_and_batch_fn)
295
+ return reader.read(input_context)
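
A minimal sketch of the token-budget batching in practice; the TFRecord glob, language keys, and SentencePiece path are placeholders.

```python
# Hypothetical usage sketch -- paths and language keys are placeholders.
from official.nlp.data import wmt_dataloader

data_config = wmt_dataloader.WMTDataConfig(
    input_path='/tmp/wmt/train-*.tfrecord',
    max_seq_length=64,
    global_batch_size=4096,          # token budget per batch when bucketing
    is_training=True,
    static_batch=False,
    src_lang='en',
    tgt_lang='de',
    sentencepiece_model_path='/tmp/wmt/spm_wmt.model')
dataset = wmt_dataloader.WMTDataLoader(data_config).load()
batch = next(iter(dataset))
# batch['inputs'] and batch['targets'] are padded id matrices whose row count
# varies per bucket so that rows * padded_length <= 4096.
```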
wmt_dataloader_test.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.nlp.data.wmt_dataloader."""
16
+ import os
17
+ from absl.testing import parameterized
18
+
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from sentencepiece import SentencePieceTrainer
22
+ from official.nlp.data import wmt_dataloader
23
+
24
+
25
+ def _generate_line_file(filepath, lines):
26
+ with tf.io.gfile.GFile(filepath, 'w') as f:
27
+ for l in lines:
28
+ f.write('{}\n'.format(l))
29
+
30
+
31
+ def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
32
+ writer = tf.io.TFRecordWriter(filepath)
33
+ for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
34
+ features = {
35
+ 'en': tf.train.Feature(
36
+ bytes_list=tf.train.BytesList(
37
+ value=[src.encode()])),
38
+ 'reverse_en': tf.train.Feature(
39
+ bytes_list=tf.train.BytesList(
40
+ value=[tgt.encode()])),
41
+ }
42
+ if unique_id:
43
+ features['unique_id'] = tf.train.Feature(
44
+ int64_list=tf.train.Int64List(value=[i]))
45
+ example = tf.train.Example(
46
+ features=tf.train.Features(
47
+ feature=features))
48
+ writer.write(example.SerializeToString())
49
+ writer.close()
50
+
51
+
52
+ def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
53
+ argstr = ' '.join([
54
+ f'--input={input_path}', f'--vocab_size={vocab_size}',
55
+ '--character_coverage=0.995',
56
+ f'--model_prefix={model_path}', '--model_type=bpe',
57
+ '--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}', '--unk_id=2'
58
+ ])
59
+ SentencePieceTrainer.Train(argstr)
60
+
61
+
62
+ class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
63
+
64
+ def setUp(self):
65
+ super(WMTDataLoaderTest, self).setUp()
66
+ self._temp_dir = self.get_temp_dir()
67
+ src_lines = [
68
+ 'abc ede fg',
69
+ 'bbcd ef a g',
70
+ 'de f a a g'
71
+ ]
72
+ tgt_lines = [
73
+ 'dd cc a ef g',
74
+ 'bcd ef a g',
75
+ 'gef cd ba'
76
+ ]
77
+ self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
78
+ _generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
79
+ self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
80
+ _generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
81
+ unique_id=True)
82
+ self._sentencepeice_input_path = os.path.join(self._temp_dir, 'inputs.txt')
83
+ _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
84
+ sentencepeice_model_prefix = os.path.join(self._temp_dir, 'sp')
85
+ _train_sentencepiece(self._sentencepeice_input_path, 20,
86
+ sentencepeice_model_prefix)
87
+ self._sentencepeice_model_path = '{}.model'.format(
88
+ sentencepeice_model_prefix)
89
+
90
+ @parameterized.named_parameters(
91
+ ('train_static', True, True, 100, (2, 35)),
92
+ ('train_non_static', True, False, 100, (12, 7)),
93
+ ('non_train_static', False, True, 3, (3, 35)),
94
+ ('non_train_non_static', False, False, 50, (2, 7)),)
95
+ def test_load_dataset(
96
+ self, is_training, static_batch, batch_size, expected_shape):
97
+ data_config = wmt_dataloader.WMTDataConfig(
98
+ input_path=self._record_train_input_path
99
+ if is_training else self._record_test_input_path,
100
+ max_seq_length=35,
101
+ global_batch_size=batch_size,
102
+ is_training=is_training,
103
+ static_batch=static_batch,
104
+ src_lang='en',
105
+ tgt_lang='reverse_en',
106
+ sentencepiece_model_path=self._sentencepeice_model_path)
107
+ dataset = wmt_dataloader.WMTDataLoader(data_config).load()
108
+ examples = next(iter(dataset))
109
+ inputs, targets = examples['inputs'], examples['targets']
110
+ self.assertEqual(inputs.shape, expected_shape)
111
+ self.assertEqual(targets.shape, expected_shape)
112
+
113
+ def test_load_dataset_raise_invalid_window(self):
114
+ batch_tokens_size = 10 # this is too small to form buckets.
115
+ data_config = wmt_dataloader.WMTDataConfig(
116
+ input_path=self._record_train_input_path,
117
+ max_seq_length=100,
118
+ global_batch_size=batch_tokens_size,
119
+ is_training=True,
120
+ static_batch=False,
121
+ src_lang='en',
122
+ tgt_lang='reverse_en',
123
+ sentencepiece_model_path=self._sentencepeice_model_path)
124
+ with self.assertRaisesRegex(
125
+ ValueError, 'The token budget, global batch size, is too small.*'):
126
+ _ = wmt_dataloader.WMTDataLoader(data_config).load()
127
+
128
+
129
+ if __name__ == '__main__':
130
+ tf.test.main()
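
As a sanity check on the static-batch expectations above: in training mode the row count is the token budget divided by `max_seq_length` (100 // 35 = 2) and every sequence is padded to 35, hence (2, 35); in the non-training static case the batch size of 3 is used directly, hence (3, 35). The non-static shapes follow from the length bucketing in `wmt_dataloader._batch_examples`.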