Pradeep Kumar committed
Commit: f18e71f
1 Parent(s): c130734
Upload 33 files
- README.md +4 -13
- __init__.py +14 -0
- classifier_data_lib.py +0 -0
- classifier_data_lib_test.py +95 -0
- create_finetuning_data.py +441 -0
- create_pretraining_data.py +718 -0
- create_pretraining_data_test.py +128 -0
- create_xlnet_pretraining_data.py +721 -0
- create_xlnet_pretraining_data_test.py +355 -0
- data_loader.py +48 -0
- data_loader_factory.py +58 -0
- data_loader_factory_test.py +45 -0
- dual_encoder_dataloader.py +147 -0
- dual_encoder_dataloader_test.py +131 -0
- pretrain_dataloader.py +589 -0
- pretrain_dataloader_test.py +242 -0
- pretrain_dynamic_dataloader.py +223 -0
- pretrain_dynamic_dataloader_test.py +245 -0
- pretrain_text_dataloader.py +226 -0
- question_answering_dataloader.py +115 -0
- question_answering_dataloader_test.py +74 -0
- sentence_prediction_dataloader.py +267 -0
- sentence_prediction_dataloader_test.py +290 -0
- sentence_retrieval_lib.py +166 -0
- squad_lib.py +975 -0
- squad_lib_sp.py +976 -0
- tagging_data_lib.py +426 -0
- tagging_data_lib_test.py +108 -0
- tagging_dataloader.py +90 -0
- tagging_dataloader_test.py +82 -0
- train_sentencepiece.py +133 -0
- wmt_dataloader.py +295 -0
- wmt_dataloader_test.py +130 -0
README.md
CHANGED
@@ -1,13 +1,4 @@
- 
- 
- 
- 
- colorTo: indigo
- sdk: gradio
- sdk_version: 4.41.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
- 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ This directory contains binaries and utils required for input preprocessing,
+ tokenization, etc that can be used with model building blocks available in
+ NLP modeling library [nlp/modelling](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)
+ to train custom models and validate new research ideas.
__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

classifier_data_lib.py
CHANGED
The diff for this file is too large to render.
See raw diff
classifier_data_lib_test.py
ADDED
@@ -0,0 +1,95 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib."""

import os
import tempfile

from absl.testing import parameterized
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds

from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization


def decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  return tf.io.parse_single_example(record, name_to_features)


class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(BertClassifierLibTest, self).setUp()
    self.model_dir = self.get_temp_dir()
    self.processors = {
        "CB": classifier_data_lib.CBProcessor,
        "SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
        "BOOLQ": classifier_data_lib.BoolQProcessor,
        "WIC": classifier_data_lib.WiCProcessor,
    }

    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      vocab_writer.write("".join([x + "\n" for x in vocab_tokens
                                 ]).encode("utf-8"))
      vocab_file = vocab_writer.name
    self.tokenizer = tokenization.FullTokenizer(vocab_file)

  @parameterized.parameters(
      {"task_type": "CB"},
      {"task_type": "BOOLQ"},
      {"task_type": "SUPERGLUE-RTE"},
      {"task_type": "WIC"},
  )
  def test_generate_dataset_from_tfds_processor(self, task_type):
    with tfds.testing.mock_data(num_examples=5):
      output_path = os.path.join(self.model_dir, task_type)

      processor = self.processors[task_type]()

      classifier_data_lib.generate_tf_record_from_data_file(
          processor,
          None,
          self.tokenizer,
          train_data_output_path=output_path,
          eval_data_output_path=output_path,
          test_data_output_path=output_path)
      files = tf.io.gfile.glob(output_path)
      self.assertNotEmpty(files)

      train_dataset = tf.data.TFRecordDataset(output_path)
      seq_length = 128
      label_type = tf.int64
      name_to_features = {
          "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
          "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "label_ids": tf.io.FixedLenFeature([], label_type),
      }
      train_dataset = train_dataset.map(
          lambda record: decode_record(record, name_to_features))

      # If data is retrieved without error, then all requirements
      # including data type/shapes are met.
      _ = next(iter(train_dataset))


if __name__ == "__main__":
  tf.test.main()
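For context, here is a minimal sketch (not part of the commit) of how the classifier TFRecords written by generate_tf_record_from_data_file might be consumed as a training input pipeline. The feature spec mirrors name_to_features from the test above; the file path, shuffle buffer, and batch size are illustrative assumptions.

# Sketch only: read back classifier TFRecords produced by
# classifier_data_lib.generate_tf_record_from_data_file.
# "train.tfrecord", seq_length and batch_size are assumed values.
import tensorflow as tf

seq_length = 128
name_to_features = {
    "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    "label_ids": tf.io.FixedLenFeature([], tf.int64),
}

def _parse(record):
  # Each serialized tf.train.Example becomes a dict of dense tensors.
  return tf.io.parse_single_example(record, name_to_features)

dataset = (
    tf.data.TFRecordDataset("train.tfrecord")  # assumed output path
    .map(_parse, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000)
    .batch(32, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE))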
create_finetuning_data.py
ADDED
@@ -0,0 +1,441 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT finetuning task dataset generator."""

import functools
import json
import os

# Import libraries
from absl import app
from absl import flags
import tensorflow as tf, tf_keras
from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "fine_tuning_task_type", "classification",
    ["classification", "regression", "squad", "retrieval", "tagging"],
    "The name of the BERT fine tuning task for which data "
    "will be generated.")

# BERT classification specific flags.
flags.DEFINE_string(
    "input_data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_enum(
    "classification_task_name", "MNLI", [
        "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
        "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
    ], "The name of the task to train BERT classifier. The "
    "difference between XTREME-XNLI and XNLI is: 1. the format "
    "of input tsv files; 2. the dev set for XTREME is english "
    "only and for XNLI is all languages combined. Same for "
    "PAWS-X.")

# MNLI task-specific flag.
flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
                  "The type of MNLI dataset.")

# XNLI task-specific flag.
flags.DEFINE_string(
    "xnli_language", "en",
    "Language of training data for XNLI task. If the value is 'all', the data "
    "of all languages will be used for training.")

# PAWS-X task-specific flag.
flags.DEFINE_string(
    "pawsx_language", "en",
    "Language of training data for PAWS-X task. If the value is 'all', the data "
    "of all languages will be used for training.")

# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
flags.DEFINE_string(
    "translated_input_data_dir", None,
    "The translated input data dir. Should contain the .tsv files (or other "
    "data files) for the task.")

# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                  "The name of sentence retrieval task for scoring")

# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                  "The name of BERT tagging (token classification) task.")

flags.DEFINE_bool("tagging_only_use_en_train", True,
                  "Whether only use english training data in tagging.")

# BERT Squad task-specific flags.
flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file in for generating training data for BERT squad task.")

flags.DEFINE_string(
    "translated_squad_data_folder", None,
    "The translated data folder for generating training data for BERT squad "
    "task.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool(
    "version_2_with_negative", False,
    "If true, the SQuAD examples contain some that do not have an answer.")

flags.DEFINE_bool(
    "xlnet_format", False,
    "If true, then data will be preprocessed in a paragraph, query, class order"
    " instead of the BERT-style class, paragraph, query order.")

# XTREME specific flags.
flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")

# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "train_data_output_path", None,
    "The path in which generated training input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "eval_data_output_path", None,
    "The path in which generated evaluation input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "test_data_output_path", None,
    "The path in which generated test input data will be written as tf"
    " records. If None, do not generate test data. Must be a pattern template"
    " as test_{}.tfrecords if processor has language specific test data.")

flags.DEFINE_string("meta_data_file_path", None,
                    "The path in which input meta data will be written.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")

flags.DEFINE_enum(
    "tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
    "while ALBERT uses SentencePiece tokenizer.")

flags.DEFINE_string(
    "tfds_params", "", "Comma-separated list of TFDS parameter assignments for "
    "generic classfication data import (for more details "
    "see the TfdsProcessor class documentation).")


def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  if FLAGS.classification_task_name in [
      "COLA",
      "WNLI",
      "SST-2",
      "MRPC",
      "QQP",
      "STS-B",
      "MNLI",
      "QNLI",
      "RTE",
      "AX",
      "SUPERGLUE-RTE",
      "CB",
      "BoolQ",
      "WIC",
  ]:
    assert not FLAGS.input_data_dir or FLAGS.tfds_params
  else:
    assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
            FLAGS.tfds_params)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    processors = {
        "ax":
            classifier_data_lib.AxProcessor,
        "cola":
            classifier_data_lib.ColaProcessor,
        "imdb":
            classifier_data_lib.ImdbProcessor,
        "mnli":
            functools.partial(
                classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
        "mrpc":
            classifier_data_lib.MrpcProcessor,
        "qnli":
            classifier_data_lib.QnliProcessor,
        "qqp":
            classifier_data_lib.QqpProcessor,
        "rte":
            classifier_data_lib.RteProcessor,
        "sst-2":
            classifier_data_lib.SstProcessor,
        "sts-b":
            classifier_data_lib.StsBProcessor,
        "xnli":
            functools.partial(
                classifier_data_lib.XnliProcessor,
                language=FLAGS.xnli_language),
        "paws-x":
            functools.partial(
                classifier_data_lib.PawsxProcessor,
                language=FLAGS.pawsx_language),
        "wnli":
            classifier_data_lib.WnliProcessor,
        "xtreme-xnli":
            functools.partial(
                classifier_data_lib.XtremeXnliProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
        "xtreme-paws-x":
            functools.partial(
                classifier_data_lib.XtremePawsxProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
        "ax-g":
            classifier_data_lib.AXgProcessor,
        "superglue-rte":
            classifier_data_lib.SuperGLUERTEProcessor,
        "cb":
            classifier_data_lib.CBProcessor,
        "boolq":
            classifier_data_lib.BoolQProcessor,
        "wic":
            classifier_data_lib.WnliProcessor,
    }
    task_name = FLAGS.classification_task_name.lower()
    if task_name not in processors:
      raise ValueError("Task not found: %s" % (task_name,))

    processor = processors[task_name](process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        FLAGS.input_data_dir,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)


def generate_regression_dataset():
  """Generates regression dataset and returns input meta data."""
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    raise ValueError("No data processor found for the given regression task.")


def generate_squad_dataset():
  """Generates squad training dataset and returns input meta data."""
  assert FLAGS.squad_data_file
  if FLAGS.tokenization == "WordPiece":
    return squad_lib_wp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        vocab_file_path=FLAGS.vocab_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        version_2_with_negative=FLAGS.version_2_with_negative,
        xlnet_format=FLAGS.xlnet_format)
  else:
    assert FLAGS.tokenization == "SentencePiece"
    return squad_lib_sp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        sp_model_file=FLAGS.sp_model_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        xlnet_format=FLAGS.xlnet_format,
        version_2_with_negative=FLAGS.version_2_with_negative)


def generate_retrieval_dataset():
  """Generate retrieval test and dev dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  processors = {
      "bucc": sentence_retrieval_lib.BuccProcessor,
      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
  }

  task_name = FLAGS.retrieval_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  processor = processors[task_name](process_text_fn=processor_text_fn)

  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, FLAGS.max_seq_length)


def generate_tagging_dataset():
  """Generates tagging dataset."""
  processors = {
      "panx":
          functools.partial(
              tagging_data_lib.PanxProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
      "udpos":
          functools.partial(
              tagging_data_lib.UdposProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  elif FLAGS.tokenization == "SentencePiece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)

  processor = processors[task_name]()
  return tagging_data_lib.generate_tf_record_from_data_file(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, processor_text_fn)


def main(_):
  if FLAGS.tokenization == "WordPiece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "FLAG vocab_file for word-piece tokenizer is not specified.")
  else:
    assert FLAGS.tokenization == "SentencePiece"
    if not FLAGS.sp_model_file:
      raise ValueError(
          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")

  if FLAGS.fine_tuning_task_type != "retrieval":
    flags.mark_flag_as_required("train_data_output_path")

  if FLAGS.fine_tuning_task_type == "classification":
    input_meta_data = generate_classifier_dataset()
  elif FLAGS.fine_tuning_task_type == "regression":
    input_meta_data = generate_regression_dataset()
  elif FLAGS.fine_tuning_task_type == "retrieval":
    input_meta_data = generate_retrieval_dataset()
  elif FLAGS.fine_tuning_task_type == "squad":
    input_meta_data = generate_squad_dataset()
  else:
    assert FLAGS.fine_tuning_task_type == "tagging"
    input_meta_data = generate_tagging_dataset()

  tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
    writer.write(json.dumps(input_meta_data, indent=4) + "\n")


if __name__ == "__main__":
  flags.mark_flag_as_required("meta_data_file_path")
  app.run(main)
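As a usage illustration only: the sketch below mirrors the WordPiece branch of generate_classifier_dataset() above, but calls the library directly instead of going through the absl flags. The vocabulary path, input data directory, and output file names are hypothetical placeholders, not part of the commit.

# Sketch: generate MNLI fine-tuning TFRecords without the flag plumbing.
# "vocab.txt", "glue/MNLI" and the *.tf_record paths are assumed placeholders.
from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="vocab.txt", do_lower_case=True)  # assumed vocab path
processor = classifier_data_lib.MnliProcessor(
    mnli_type="matched",
    process_text_fn=tokenization.convert_to_unicode)

# Returns the input meta data dict that the script writes out as JSON.
input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
    processor,
    "glue/MNLI",                      # assumed input_data_dir with .tsv files
    tokenizer,
    train_data_output_path="mnli_train.tf_record",
    eval_data_output_path="mnli_eval.tf_record",
    test_data_output_path=None,
    max_seq_length=128)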
create_pretraining_data.py
ADDED
@@ -0,0 +1,718 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
|
16 |
+
|
17 |
+
import collections
|
18 |
+
import itertools
|
19 |
+
import random
|
20 |
+
|
21 |
+
# Import libraries
|
22 |
+
|
23 |
+
from absl import app
|
24 |
+
from absl import flags
|
25 |
+
from absl import logging
|
26 |
+
import tensorflow as tf, tf_keras
|
27 |
+
|
28 |
+
from official.nlp.tools import tokenization
|
29 |
+
|
30 |
+
FLAGS = flags.FLAGS
|
31 |
+
|
32 |
+
flags.DEFINE_string("input_file", None,
|
33 |
+
"Input raw text file (or comma-separated list of files).")
|
34 |
+
|
35 |
+
flags.DEFINE_string(
|
36 |
+
"output_file", None,
|
37 |
+
"Output TF example file (or comma-separated list of files).")
|
38 |
+
|
39 |
+
flags.DEFINE_enum(
|
40 |
+
"tokenization",
|
41 |
+
"WordPiece",
|
42 |
+
["WordPiece", "SentencePiece"],
|
43 |
+
"Specifies the tokenizer implementation, i.e., whether to use WordPiece "
|
44 |
+
"or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
|
45 |
+
"while ALBERT uses SentencePiece tokenizer.",
|
46 |
+
)
|
47 |
+
|
48 |
+
flags.DEFINE_string(
|
49 |
+
"vocab_file",
|
50 |
+
None,
|
51 |
+
"For WordPiece tokenization, the vocabulary file of the tokenizer.",
|
52 |
+
)
|
53 |
+
|
54 |
+
flags.DEFINE_string(
|
55 |
+
"sp_model_file",
|
56 |
+
"",
|
57 |
+
"For SentencePiece tokenization, the path to the model of the tokenizer.",
|
58 |
+
)
|
59 |
+
|
60 |
+
flags.DEFINE_bool(
|
61 |
+
"do_lower_case", True,
|
62 |
+
"Whether to lower case the input text. Should be True for uncased "
|
63 |
+
"models and False for cased models.")
|
64 |
+
|
65 |
+
flags.DEFINE_bool(
|
66 |
+
"do_whole_word_mask",
|
67 |
+
False,
|
68 |
+
"Whether to use whole word masking rather than per-token masking.",
|
69 |
+
)
|
70 |
+
|
71 |
+
flags.DEFINE_integer(
|
72 |
+
"max_ngram_size", None,
|
73 |
+
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
|
74 |
+
"weighting scheme to favor shorter n-grams. "
|
75 |
+
"Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
|
76 |
+
|
77 |
+
flags.DEFINE_bool(
|
78 |
+
"gzip_compress", False,
|
79 |
+
"Whether to use `GZIP` compress option to get compressed TFRecord files.")
|
80 |
+
|
81 |
+
flags.DEFINE_bool(
|
82 |
+
"use_v2_feature_names", False,
|
83 |
+
"Whether to use the feature names consistent with the models.")
|
84 |
+
|
85 |
+
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
|
86 |
+
|
87 |
+
flags.DEFINE_integer("max_predictions_per_seq", 20,
|
88 |
+
"Maximum number of masked LM predictions per sequence.")
|
89 |
+
|
90 |
+
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
91 |
+
|
92 |
+
flags.DEFINE_integer(
|
93 |
+
"dupe_factor", 10,
|
94 |
+
"Number of times to duplicate the input data (with different masks).")
|
95 |
+
|
96 |
+
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
97 |
+
|
98 |
+
flags.DEFINE_float(
|
99 |
+
"short_seq_prob", 0.1,
|
100 |
+
"Probability of creating sequences which are shorter than the "
|
101 |
+
"maximum length.")
|
102 |
+
|
103 |
+
|
104 |
+
class TrainingInstance(object):
|
105 |
+
"""A single training instance (sentence pair)."""
|
106 |
+
|
107 |
+
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
108 |
+
is_random_next):
|
109 |
+
self.tokens = tokens
|
110 |
+
self.segment_ids = segment_ids
|
111 |
+
self.is_random_next = is_random_next
|
112 |
+
self.masked_lm_positions = masked_lm_positions
|
113 |
+
self.masked_lm_labels = masked_lm_labels
|
114 |
+
|
115 |
+
def __str__(self):
|
116 |
+
s = ""
|
117 |
+
s += "tokens: %s\n" % (" ".join(
|
118 |
+
[tokenization.printable_text(x) for x in self.tokens]))
|
119 |
+
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
120 |
+
s += "is_random_next: %s\n" % self.is_random_next
|
121 |
+
s += "masked_lm_positions: %s\n" % (" ".join(
|
122 |
+
[str(x) for x in self.masked_lm_positions]))
|
123 |
+
s += "masked_lm_labels: %s\n" % (" ".join(
|
124 |
+
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
125 |
+
s += "\n"
|
126 |
+
return s
|
127 |
+
|
128 |
+
def __repr__(self):
|
129 |
+
return self.__str__()
|
130 |
+
|
131 |
+
|
132 |
+
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
133 |
+
max_predictions_per_seq, output_files,
|
134 |
+
gzip_compress, use_v2_feature_names):
|
135 |
+
"""Creates TF example files from `TrainingInstance`s."""
|
136 |
+
writers = []
|
137 |
+
for output_file in output_files:
|
138 |
+
writers.append(
|
139 |
+
tf.io.TFRecordWriter(
|
140 |
+
output_file, options="GZIP" if gzip_compress else ""))
|
141 |
+
|
142 |
+
writer_index = 0
|
143 |
+
|
144 |
+
total_written = 0
|
145 |
+
for (inst_index, instance) in enumerate(instances):
|
146 |
+
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
147 |
+
input_mask = [1] * len(input_ids)
|
148 |
+
segment_ids = list(instance.segment_ids)
|
149 |
+
assert len(input_ids) <= max_seq_length
|
150 |
+
|
151 |
+
while len(input_ids) < max_seq_length:
|
152 |
+
input_ids.append(0)
|
153 |
+
input_mask.append(0)
|
154 |
+
segment_ids.append(0)
|
155 |
+
|
156 |
+
assert len(input_ids) == max_seq_length
|
157 |
+
assert len(input_mask) == max_seq_length
|
158 |
+
assert len(segment_ids) == max_seq_length
|
159 |
+
|
160 |
+
masked_lm_positions = list(instance.masked_lm_positions)
|
161 |
+
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
162 |
+
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
163 |
+
|
164 |
+
while len(masked_lm_positions) < max_predictions_per_seq:
|
165 |
+
masked_lm_positions.append(0)
|
166 |
+
masked_lm_ids.append(0)
|
167 |
+
masked_lm_weights.append(0.0)
|
168 |
+
|
169 |
+
next_sentence_label = 1 if instance.is_random_next else 0
|
170 |
+
|
171 |
+
features = collections.OrderedDict()
|
172 |
+
if use_v2_feature_names:
|
173 |
+
features["input_word_ids"] = create_int_feature(input_ids)
|
174 |
+
features["input_type_ids"] = create_int_feature(segment_ids)
|
175 |
+
else:
|
176 |
+
features["input_ids"] = create_int_feature(input_ids)
|
177 |
+
features["segment_ids"] = create_int_feature(segment_ids)
|
178 |
+
|
179 |
+
features["input_mask"] = create_int_feature(input_mask)
|
180 |
+
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
181 |
+
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
182 |
+
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
183 |
+
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
184 |
+
|
185 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
186 |
+
|
187 |
+
writers[writer_index].write(tf_example.SerializeToString())
|
188 |
+
writer_index = (writer_index + 1) % len(writers)
|
189 |
+
|
190 |
+
total_written += 1
|
191 |
+
|
192 |
+
if inst_index < 20:
|
193 |
+
logging.info("*** Example ***")
|
194 |
+
logging.info("tokens: %s", " ".join(
|
195 |
+
[tokenization.printable_text(x) for x in instance.tokens]))
|
196 |
+
|
197 |
+
for feature_name in features.keys():
|
198 |
+
feature = features[feature_name]
|
199 |
+
values = []
|
200 |
+
if feature.int64_list.value:
|
201 |
+
values = feature.int64_list.value
|
202 |
+
elif feature.float_list.value:
|
203 |
+
values = feature.float_list.value
|
204 |
+
logging.info("%s: %s", feature_name, " ".join([str(x) for x in values]))
|
205 |
+
|
206 |
+
for writer in writers:
|
207 |
+
writer.close()
|
208 |
+
|
209 |
+
logging.info("Wrote %d total instances", total_written)
|
210 |
+
|
211 |
+
|
212 |
+
def create_int_feature(values):
|
213 |
+
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
214 |
+
return feature
|
215 |
+
|
216 |
+
|
217 |
+
def create_float_feature(values):
|
218 |
+
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
219 |
+
return feature
|
220 |
+
|
221 |
+
|
222 |
+
def create_training_instances(
|
223 |
+
input_files,
|
224 |
+
tokenizer,
|
225 |
+
processor_text_fn,
|
226 |
+
max_seq_length,
|
227 |
+
dupe_factor,
|
228 |
+
short_seq_prob,
|
229 |
+
masked_lm_prob,
|
230 |
+
max_predictions_per_seq,
|
231 |
+
rng,
|
232 |
+
do_whole_word_mask=False,
|
233 |
+
max_ngram_size=None,
|
234 |
+
):
|
235 |
+
"""Create `TrainingInstance`s from raw text."""
|
236 |
+
all_documents = [[]]
|
237 |
+
|
238 |
+
# Input file format:
|
239 |
+
# (1) One sentence per line. These should ideally be actual sentences, not
|
240 |
+
# entire paragraphs or arbitrary spans of text. (Because we use the
|
241 |
+
# sentence boundaries for the "next sentence prediction" task).
|
242 |
+
# (2) Blank lines between documents. Document boundaries are needed so
|
243 |
+
# that the "next sentence prediction" task doesn't span between documents.
|
244 |
+
for input_file in input_files:
|
245 |
+
with tf.io.gfile.GFile(input_file, "rb") as reader:
|
246 |
+
for line in reader:
|
247 |
+
line = processor_text_fn(line)
|
248 |
+
|
249 |
+
# Empty lines are used as document delimiters
|
250 |
+
if not line:
|
251 |
+
all_documents.append([])
|
252 |
+
tokens = tokenizer.tokenize(line)
|
253 |
+
if tokens:
|
254 |
+
all_documents[-1].append(tokens)
|
255 |
+
|
256 |
+
# Remove empty documents
|
257 |
+
all_documents = [x for x in all_documents if x]
|
258 |
+
rng.shuffle(all_documents)
|
259 |
+
|
260 |
+
vocab_words = list(tokenizer.vocab.keys())
|
261 |
+
instances = []
|
262 |
+
for _ in range(dupe_factor):
|
263 |
+
for document_index in range(len(all_documents)):
|
264 |
+
instances.extend(
|
265 |
+
create_instances_from_document(
|
266 |
+
all_documents, document_index, max_seq_length, short_seq_prob,
|
267 |
+
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
|
268 |
+
do_whole_word_mask, max_ngram_size))
|
269 |
+
|
270 |
+
rng.shuffle(instances)
|
271 |
+
return instances
|
272 |
+
|
273 |
+
|
274 |
+
def create_instances_from_document(
|
275 |
+
all_documents, document_index, max_seq_length, short_seq_prob,
|
276 |
+
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
|
277 |
+
do_whole_word_mask=False,
|
278 |
+
max_ngram_size=None):
|
279 |
+
"""Creates `TrainingInstance`s for a single document."""
|
280 |
+
document = all_documents[document_index]
|
281 |
+
|
282 |
+
# Account for [CLS], [SEP], [SEP]
|
283 |
+
max_num_tokens = max_seq_length - 3
|
284 |
+
|
285 |
+
# We *usually* want to fill up the entire sequence since we are padding
|
286 |
+
# to `max_seq_length` anyways, so short sequences are generally wasted
|
287 |
+
# computation. However, we *sometimes*
|
288 |
+
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
289 |
+
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
290 |
+
# The `target_seq_length` is just a rough target however, whereas
|
291 |
+
# `max_seq_length` is a hard limit.
|
292 |
+
target_seq_length = max_num_tokens
|
293 |
+
if rng.random() < short_seq_prob:
|
294 |
+
target_seq_length = rng.randint(2, max_num_tokens)
|
295 |
+
|
296 |
+
# We DON'T just concatenate all of the tokens from a document into a long
|
297 |
+
# sequence and choose an arbitrary split point because this would make the
|
298 |
+
# next sentence prediction task too easy. Instead, we split the input into
|
299 |
+
# segments "A" and "B" based on the actual "sentences" provided by the user
|
300 |
+
# input.
|
301 |
+
instances = []
|
302 |
+
current_chunk = []
|
303 |
+
current_length = 0
|
304 |
+
i = 0
|
305 |
+
while i < len(document):
|
306 |
+
segment = document[i]
|
307 |
+
current_chunk.append(segment)
|
308 |
+
current_length += len(segment)
|
309 |
+
if i == len(document) - 1 or current_length >= target_seq_length:
|
310 |
+
if current_chunk:
|
311 |
+
# `a_end` is how many segments from `current_chunk` go into the `A`
|
312 |
+
# (first) sentence.
|
313 |
+
a_end = 1
|
314 |
+
if len(current_chunk) >= 2:
|
315 |
+
a_end = rng.randint(1, len(current_chunk) - 1)
|
316 |
+
|
317 |
+
tokens_a = []
|
318 |
+
for j in range(a_end):
|
319 |
+
tokens_a.extend(current_chunk[j])
|
320 |
+
|
321 |
+
tokens_b = []
|
322 |
+
# Random next
|
323 |
+
is_random_next = False
|
324 |
+
if len(current_chunk) == 1 or rng.random() < 0.5:
|
325 |
+
is_random_next = True
|
326 |
+
target_b_length = target_seq_length - len(tokens_a)
|
327 |
+
|
328 |
+
# This should rarely go for more than one iteration for large
|
329 |
+
# corpora. However, just to be careful, we try to make sure that
|
330 |
+
# the random document is not the same as the document
|
331 |
+
# we're processing.
|
332 |
+
for _ in range(10):
|
333 |
+
random_document_index = rng.randint(0, len(all_documents) - 1)
|
334 |
+
if random_document_index != document_index:
|
335 |
+
break
|
336 |
+
|
337 |
+
random_document = all_documents[random_document_index]
|
338 |
+
random_start = rng.randint(0, len(random_document) - 1)
|
339 |
+
for j in range(random_start, len(random_document)):
|
340 |
+
tokens_b.extend(random_document[j])
|
341 |
+
if len(tokens_b) >= target_b_length:
|
342 |
+
break
|
343 |
+
# We didn't actually use these segments so we "put them back" so
|
344 |
+
# they don't go to waste.
|
345 |
+
num_unused_segments = len(current_chunk) - a_end
|
346 |
+
i -= num_unused_segments
|
347 |
+
# Actual next
|
348 |
+
else:
|
349 |
+
is_random_next = False
|
350 |
+
for j in range(a_end, len(current_chunk)):
|
351 |
+
tokens_b.extend(current_chunk[j])
|
352 |
+
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
353 |
+
|
354 |
+
assert len(tokens_a) >= 1
|
355 |
+
assert len(tokens_b) >= 1
|
356 |
+
|
357 |
+
tokens = []
|
358 |
+
segment_ids = []
|
359 |
+
tokens.append("[CLS]")
|
360 |
+
segment_ids.append(0)
|
361 |
+
for token in tokens_a:
|
362 |
+
tokens.append(token)
|
363 |
+
segment_ids.append(0)
|
364 |
+
|
365 |
+
tokens.append("[SEP]")
|
366 |
+
segment_ids.append(0)
|
367 |
+
|
368 |
+
for token in tokens_b:
|
369 |
+
tokens.append(token)
|
370 |
+
segment_ids.append(1)
|
371 |
+
tokens.append("[SEP]")
|
372 |
+
segment_ids.append(1)
|
373 |
+
|
374 |
+
(tokens, masked_lm_positions,
|
375 |
+
masked_lm_labels) = create_masked_lm_predictions(
|
376 |
+
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
|
377 |
+
do_whole_word_mask, max_ngram_size)
|
378 |
+
instance = TrainingInstance(
|
379 |
+
tokens=tokens,
|
380 |
+
segment_ids=segment_ids,
|
381 |
+
is_random_next=is_random_next,
|
382 |
+
masked_lm_positions=masked_lm_positions,
|
383 |
+
masked_lm_labels=masked_lm_labels)
|
384 |
+
instances.append(instance)
|
385 |
+
current_chunk = []
|
386 |
+
current_length = 0
|
387 |
+
i += 1
|
388 |
+
|
389 |
+
return instances
|
390 |
+
|
391 |
+
|
392 |
+
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
393 |
+
["index", "label"])
|
394 |
+
|
395 |
+
# A _Gram is a [half-open) interval of token indices which form a word.
|
396 |
+
# E.g.,
|
397 |
+
# words: ["The", "doghouse"]
|
398 |
+
# tokens: ["The", "dog", "##house"]
|
399 |
+
# grams: [(0,1), (1,3)]
|
400 |
+
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
|
401 |
+
|
402 |
+
|
403 |
+
def _window(iterable, size):
|
404 |
+
"""Helper to create a sliding window iterator with a given size.
|
405 |
+
|
406 |
+
E.g.,
|
407 |
+
input = [1, 2, 3, 4]
|
408 |
+
_window(input, 1) => [1], [2], [3], [4]
|
409 |
+
_window(input, 2) => [1, 2], [2, 3], [3, 4]
|
410 |
+
_window(input, 3) => [1, 2, 3], [2, 3, 4]
|
411 |
+
_window(input, 4) => [1, 2, 3, 4]
|
412 |
+
_window(input, 5) => None
|
413 |
+
|
414 |
+
Args:
|
415 |
+
iterable: elements to iterate over.
|
416 |
+
size: size of the window.
|
417 |
+
|
418 |
+
Yields:
|
419 |
+
Elements of `iterable` batched into a sliding window of length `size`.
|
420 |
+
"""
|
421 |
+
i = iter(iterable)
|
422 |
+
window = []
|
423 |
+
try:
|
424 |
+
for e in range(0, size):
|
425 |
+
window.append(next(i))
|
426 |
+
yield window
|
427 |
+
except StopIteration:
|
428 |
+
# handle the case where iterable's length is less than the window size.
|
429 |
+
return
|
430 |
+
for e in i:
|
431 |
+
window = window[1:] + [e]
|
432 |
+
yield window
|
433 |
+
|
434 |
+
|
435 |
+
def _contiguous(sorted_grams):
|
436 |
+
"""Test whether a sequence of grams is contiguous.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
sorted_grams: _Grams which are sorted in increasing order.
|
440 |
+
Returns:
|
441 |
+
True if `sorted_grams` are touching each other.
|
442 |
+
|
443 |
+
E.g.,
|
444 |
+
_contiguous([(1, 4), (4, 5), (5, 10)]) == True
|
445 |
+
_contiguous([(1, 2), (4, 5)]) == False
|
446 |
+
"""
|
447 |
+
for a, b in _window(sorted_grams, 2):
|
448 |
+
if a.end != b.begin:
|
449 |
+
return False
|
450 |
+
return True
|
451 |
+
|
452 |
+
|
453 |
+
def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
|
454 |
+
"""Create a list of masking {1, ..., n}-grams from a list of one-grams.
|
455 |
+
|
456 |
+
This is an extension of 'whole word masking' to mask multiple, contiguous
|
457 |
+
words such as (e.g., "the red boat").
|
458 |
+
|
459 |
+
Each input gram represents the token indices of a single word,
|
460 |
+
words: ["the", "red", "boat"]
|
461 |
+
tokens: ["the", "red", "boa", "##t"]
|
462 |
+
grams: [(0,1), (1,2), (2,4)]
|
463 |
+
|
464 |
+
For a `max_ngram_size` of three, possible outputs masks include:
|
465 |
+
1-grams: (0,1), (1,2), (2,4)
|
466 |
+
2-grams: (0,2), (1,4)
|
467 |
+
3-grams; (0,4)
|
468 |
+
|
469 |
+
Output masks will not overlap and contain less than `max_masked_tokens` total
|
470 |
+
tokens. E.g., for the example above with `max_masked_tokens` as three,
|
471 |
+
valid outputs are,
|
472 |
+
[(0,1), (1,2)] # "the", "red" covering two tokens
|
473 |
+
[(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
|
474 |
+
|
475 |
+
The length of the selected n-gram follows a zipf weighting to
|
476 |
+
favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
|
477 |
+
|
478 |
+
Args:
|
479 |
+
grams: List of one-grams.
|
480 |
+
max_ngram_size: Maximum number of contiguous one-grams combined to create
|
481 |
+
an n-gram.
|
482 |
+
max_masked_tokens: Maximum total number of tokens to be masked.
|
483 |
+
rng: `random.Random` generator.
|
484 |
+
|
485 |
+
Returns:
|
486 |
+
A list of n-grams to be used as masks.
|
487 |
+
"""
|
488 |
+
if not grams:
|
489 |
+
return None
|
490 |
+
|
491 |
+
grams = sorted(grams)
|
492 |
+
num_tokens = grams[-1].end
|
493 |
+
|
494 |
+
# Ensure our grams are valid (i.e., they don't overlap).
|
495 |
+
for a, b in _window(grams, 2):
|
496 |
+
if a.end > b.begin:
|
497 |
+
raise ValueError("overlapping grams: {}".format(grams))
|
498 |
+
|
499 |
+
# Build map from n-gram length to list of n-grams.
|
500 |
+
ngrams = {i: [] for i in range(1, max_ngram_size+1)}
|
501 |
+
for gram_size in range(1, max_ngram_size+1):
|
502 |
+
for g in _window(grams, gram_size):
|
503 |
+
if _contiguous(g):
|
504 |
+
# Add an n-gram which spans these one-grams.
|
505 |
+
ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
|
506 |
+
|
507 |
+
# Shuffle each list of n-grams.
|
508 |
+
for v in ngrams.values():
|
509 |
+
rng.shuffle(v)
|
510 |
+
|
511 |
+
# Create the weighting for n-gram length selection.
|
512 |
+
# Stored cumulatively for `random.choices` below.
|
513 |
+
cummulative_weights = list(
|
514 |
+
itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
|
515 |
+
|
516 |
+
output_ngrams = []
|
517 |
+
# Keep a bitmask of which tokens have been masked.
|
518 |
+
masked_tokens = [False] * num_tokens
|
519 |
+
# Loop until we have enough masked tokens or there are no more candidate
|
520 |
+
# n-grams of any length.
|
521 |
+
# Each code path should ensure one or more elements from `ngrams` are removed
|
522 |
+
# to guarantee this loop terminates.
|
523 |
+
while (sum(masked_tokens) < max_masked_tokens and
|
524 |
+
sum(len(s) for s in ngrams.values())):
|
525 |
+
# Pick an n-gram size based on our weights.
|
526 |
+
sz = random.choices(range(1, max_ngram_size+1),
|
527 |
+
cum_weights=cummulative_weights)[0]
|
528 |
+
|
529 |
+
# Ensure this size doesn't result in too many masked tokens.
|
530 |
+
# E.g., a two-gram contains _at least_ two tokens.
|
531 |
+
if sum(masked_tokens) + sz > max_masked_tokens:
|
532 |
+
# All n-grams of this length are too long and can be removed from
|
533 |
+
# consideration.
|
534 |
+
ngrams[sz].clear()
|
535 |
+
continue
|
536 |
+
|
537 |
+
# All of the n-grams of this size have been used.
|
538 |
+
if not ngrams[sz]:
|
539 |
+
continue
|
540 |
+
|
541 |
+
# Choose a random n-gram of the given size.
|
542 |
+
gram = ngrams[sz].pop()
|
543 |
+
num_gram_tokens = gram.end-gram.begin
|
544 |
+
|
545 |
+
# Check if this would add too many tokens.
|
546 |
+
if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
|
547 |
+
continue
|
548 |
+
|
549 |
+
# Check if any of the tokens in this gram have already been masked.
|
550 |
+
if sum(masked_tokens[gram.begin:gram.end]):
|
551 |
+
continue
|
552 |
+
|
553 |
+
# Found a usable n-gram! Mark its tokens as masked and add it to return.
|
554 |
+
masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
|
555 |
+
output_ngrams.append(gram)
|
556 |
+
return output_ngrams
|
557 |
+
|
558 |
+
|
559 |
+
def _tokens_to_grams(tokens):
|
560 |
+
"""Reconstitue grams (words) from `tokens`.
|
561 |
+
|
562 |
+
E.g.,
|
563 |
+
tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
|
564 |
+
grams: [ [1,2), [2, 4), [4,5) , [5, 6)]
|
565 |
+
|
566 |
+
Args:
|
567 |
+
tokens: list of tokens (word pieces or sentence pieces).
|
568 |
+
|
569 |
+
Returns:
|
570 |
+
List of _Grams representing spans of whole words
|
571 |
+
(without "[CLS]" and "[SEP]").
|
572 |
+
"""
|
573 |
+
grams = []
|
574 |
+
gram_start_pos = None
|
575 |
+
for i, token in enumerate(tokens):
|
576 |
+
if gram_start_pos is not None and token.startswith("##"):
|
577 |
+
continue
|
578 |
+
if gram_start_pos is not None:
|
579 |
+
grams.append(_Gram(gram_start_pos, i))
|
580 |
+
if token not in ["[CLS]", "[SEP]"]:
|
581 |
+
gram_start_pos = i
|
582 |
+
else:
|
583 |
+
gram_start_pos = None
|
584 |
+
if gram_start_pos is not None:
|
585 |
+
grams.append(_Gram(gram_start_pos, len(tokens)))
|
586 |
+
return grams
|
587 |
+
|
588 |
+
|
589 |
+
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
590 |
+
max_predictions_per_seq, vocab_words, rng,
|
591 |
+
do_whole_word_mask,
|
592 |
+
max_ngram_size=None):
|
593 |
+
"""Creates the predictions for the masked LM objective."""
|
594 |
+
if do_whole_word_mask:
|
595 |
+
grams = _tokens_to_grams(tokens)
|
596 |
+
else:
|
597 |
+
# Here we consider each token to be a word to allow for sub-word masking.
|
598 |
+
if max_ngram_size:
|
599 |
+
raise ValueError("cannot use ngram masking without whole word masking")
|
600 |
+
grams = [_Gram(i, i+1) for i in range(0, len(tokens))
|
601 |
+
if tokens[i] not in ["[CLS]", "[SEP]"]]
|
602 |
+
|
603 |
+
num_to_predict = min(max_predictions_per_seq,
|
604 |
+
max(1, int(round(len(tokens) * masked_lm_prob))))
|
605 |
+
# Generate masks. If `max_ngram_size` in [0, None] it means we're doing
|
606 |
+
# whole word masking or token level masking. Both of these can be treated
|
607 |
+
# as the `max_ngram_size=1` case.
|
608 |
+
masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
|
609 |
+
num_to_predict, rng)
|
610 |
+
masked_lms = []
|
611 |
+
output_tokens = list(tokens)
|
612 |
+
for gram in masked_grams:
|
613 |
+
# 80% of the time, replace all n-gram tokens with [MASK]
|
614 |
+
if rng.random() < 0.8:
|
615 |
+
replacement_action = lambda idx: "[MASK]"
|
616 |
+
else:
|
617 |
+
# 10% of the time, keep all the original n-gram tokens.
|
618 |
+
if rng.random() < 0.5:
|
619 |
+
replacement_action = lambda idx: tokens[idx]
|
620 |
+
# 10% of the time, replace each n-gram token with a random word.
|
621 |
+
else:
|
622 |
+
replacement_action = lambda idx: rng.choice(vocab_words)
|
623 |
+
|
624 |
+
for idx in range(gram.begin, gram.end):
|
625 |
+
output_tokens[idx] = replacement_action(idx)
|
626 |
+
masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
|
627 |
+
|
628 |
+
assert len(masked_lms) <= num_to_predict
|
629 |
+
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
630 |
+
|
631 |
+
masked_lm_positions = []
|
632 |
+
masked_lm_labels = []
|
633 |
+
for p in masked_lms:
|
634 |
+
masked_lm_positions.append(p.index)
|
635 |
+
masked_lm_labels.append(p.label)
|
636 |
+
|
637 |
+
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
638 |
+
|
639 |
+
|
640 |
+
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
641 |
+
"""Truncates a pair of sequences to a maximum sequence length."""
|
642 |
+
while True:
|
643 |
+
total_length = len(tokens_a) + len(tokens_b)
|
644 |
+
if total_length <= max_num_tokens:
|
645 |
+
break
|
646 |
+
|
647 |
+
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
648 |
+
assert len(trunc_tokens) >= 1
|
649 |
+
|
650 |
+
# We want to sometimes truncate from the front and sometimes from the
|
651 |
+
# back to add more randomness and avoid biases.
|
652 |
+
if rng.random() < 0.5:
|
653 |
+
del trunc_tokens[0]
|
654 |
+
else:
|
655 |
+
trunc_tokens.pop()
|
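Toy usage of `truncate_seq_pair` (illustrative values only): the longer list is trimmed in place, randomly from the front or the back, until the pair fits.

tokens_a, tokens_b = list("abcdef"), list("xyz")
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens=6, rng=random.Random(0))
assert len(tokens_a) + len(tokens_b) <= 6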
656 |
+
|
657 |
+
|
658 |
+
def get_processor_text_fn(is_sentence_piece, do_lower_case):
|
659 |
+
def processor_text_fn(text):
|
660 |
+
text = tokenization.convert_to_unicode(text)
|
661 |
+
if is_sentence_piece:
|
662 |
+
# Additional preprocessing specific to the SentencePiece tokenizer.
|
663 |
+
text = tokenization.preprocess_text(text, lower=do_lower_case)
|
664 |
+
|
665 |
+
return text.strip()
|
666 |
+
|
667 |
+
return processor_text_fn
|
668 |
+
|
669 |
+
|
670 |
+
def main(_):
|
671 |
+
if FLAGS.tokenization == "WordPiece":
|
672 |
+
tokenizer = tokenization.FullTokenizer(
|
673 |
+
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
|
674 |
+
)
|
675 |
+
processor_text_fn = get_processor_text_fn(False, FLAGS.do_lower_case)
|
676 |
+
else:
|
677 |
+
assert FLAGS.tokenization == "SentencePiece"
|
678 |
+
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
|
679 |
+
processor_text_fn = get_processor_text_fn(True, FLAGS.do_lower_case)
|
680 |
+
|
681 |
+
input_files = []
|
682 |
+
for input_pattern in FLAGS.input_file.split(","):
|
683 |
+
input_files.extend(tf.io.gfile.glob(input_pattern))
|
684 |
+
|
685 |
+
logging.info("*** Reading from input files ***")
|
686 |
+
for input_file in input_files:
|
687 |
+
logging.info(" %s", input_file)
|
688 |
+
|
689 |
+
rng = random.Random(FLAGS.random_seed)
|
690 |
+
instances = create_training_instances(
|
691 |
+
input_files,
|
692 |
+
tokenizer,
|
693 |
+
processor_text_fn,
|
694 |
+
FLAGS.max_seq_length,
|
695 |
+
FLAGS.dupe_factor,
|
696 |
+
FLAGS.short_seq_prob,
|
697 |
+
FLAGS.masked_lm_prob,
|
698 |
+
FLAGS.max_predictions_per_seq,
|
699 |
+
rng,
|
700 |
+
FLAGS.do_whole_word_mask,
|
701 |
+
FLAGS.max_ngram_size,
|
702 |
+
)
|
703 |
+
|
704 |
+
output_files = FLAGS.output_file.split(",")
|
705 |
+
logging.info("*** Writing to output files ***")
|
706 |
+
for output_file in output_files:
|
707 |
+
logging.info(" %s", output_file)
|
708 |
+
|
709 |
+
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
710 |
+
FLAGS.max_predictions_per_seq, output_files,
|
711 |
+
FLAGS.gzip_compress,
|
712 |
+
FLAGS.use_v2_feature_names)
|
713 |
+
|
714 |
+
|
715 |
+
if __name__ == "__main__":
|
716 |
+
flags.mark_flag_as_required("input_file")
|
717 |
+
flags.mark_flag_as_required("output_file")
|
718 |
+
app.run(main)
|
create_pretraining_data_test.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for official.nlp.data.create_pretraining_data."""
|
16 |
+
import random
|
17 |
+
|
18 |
+
import tensorflow as tf, tf_keras
|
19 |
+
|
20 |
+
from official.nlp.data import create_pretraining_data as cpd
|
21 |
+
|
22 |
+
_VOCAB_WORDS = ["vocab_1", "vocab_2"]
|
23 |
+
|
24 |
+
|
25 |
+
class CreatePretrainingDataTest(tf.test.TestCase):
|
26 |
+
|
27 |
+
def assertTokens(self, input_tokens, output_tokens, masked_positions,
|
28 |
+
masked_labels):
|
29 |
+
# Ensure the masked positions are unique.
|
30 |
+
self.assertCountEqual(masked_positions, set(masked_positions))
|
31 |
+
|
32 |
+
# Ensure we can reconstruct the input from the output.
|
33 |
+
reconstructed_tokens = output_tokens
|
34 |
+
for pos, label in zip(masked_positions, masked_labels):
|
35 |
+
reconstructed_tokens[pos] = label
|
36 |
+
self.assertEqual(input_tokens, reconstructed_tokens)
|
37 |
+
|
38 |
+
# Ensure each label is valid.
|
39 |
+
for pos, label in zip(masked_positions, masked_labels):
|
40 |
+
output_token = output_tokens[pos]
|
41 |
+
if (output_token == "[MASK]" or output_token in _VOCAB_WORDS or
|
42 |
+
output_token == input_tokens[pos]):
|
43 |
+
continue
|
44 |
+
self.fail("invalid mask value: {}".format(output_token))
|
45 |
+
|
46 |
+
def test_tokens_to_grams(self):
|
47 |
+
tests = [
|
48 |
+
(["That", "cone"], [(0, 1), (1, 2)]),
|
49 |
+
(["That", "cone", "##s"], [(0, 1), (1, 3)]),
|
50 |
+
(["Swit", "##zer", "##land"], [(0, 3)]),
|
51 |
+
(["[CLS]", "Up", "##dog"], [(1, 3)]),
|
52 |
+
(["[CLS]", "Up", "##dog", "[SEP]", "Down"], [(1, 3), (4, 5)]),
|
53 |
+
]
|
54 |
+
for inp, expected in tests:
|
55 |
+
output = cpd._tokens_to_grams(inp)
|
56 |
+
self.assertEqual(expected, output)
|
57 |
+
|
58 |
+
def test_window(self):
|
59 |
+
input_list = [1, 2, 3, 4]
|
60 |
+
window_outputs = [
|
61 |
+
(1, [[1], [2], [3], [4]]),
|
62 |
+
(2, [[1, 2], [2, 3], [3, 4]]),
|
63 |
+
(3, [[1, 2, 3], [2, 3, 4]]),
|
64 |
+
(4, [[1, 2, 3, 4]]),
|
65 |
+
(5, []),
|
66 |
+
]
|
67 |
+
for window, expected in window_outputs:
|
68 |
+
output = cpd._window(input_list, window)
|
69 |
+
self.assertEqual(expected, list(output))
|
70 |
+
|
71 |
+
def test_create_masked_lm_predictions(self):
|
72 |
+
tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
|
73 |
+
rng = random.Random(123)
|
74 |
+
for _ in range(0, 5):
|
75 |
+
output_tokens, masked_positions, masked_labels = (
|
76 |
+
cpd.create_masked_lm_predictions(
|
77 |
+
tokens=tokens,
|
78 |
+
masked_lm_prob=1.0,
|
79 |
+
max_predictions_per_seq=3,
|
80 |
+
vocab_words=_VOCAB_WORDS,
|
81 |
+
rng=rng,
|
82 |
+
do_whole_word_mask=False,
|
83 |
+
max_ngram_size=None))
|
84 |
+
self.assertLen(masked_positions, 3)
|
85 |
+
self.assertLen(masked_labels, 3)
|
86 |
+
self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
|
87 |
+
|
88 |
+
def test_create_masked_lm_predictions_whole_word(self):
|
89 |
+
tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
|
90 |
+
rng = random.Random(345)
|
91 |
+
for _ in range(0, 5):
|
92 |
+
output_tokens, masked_positions, masked_labels = (
|
93 |
+
cpd.create_masked_lm_predictions(
|
94 |
+
tokens=tokens,
|
95 |
+
masked_lm_prob=1.0,
|
96 |
+
max_predictions_per_seq=3,
|
97 |
+
vocab_words=_VOCAB_WORDS,
|
98 |
+
rng=rng,
|
99 |
+
do_whole_word_mask=True,
|
100 |
+
max_ngram_size=None))
|
101 |
+
# Since we can't get exactly three tokens without breaking a word, we
|
102 |
+
# only take two.
|
103 |
+
self.assertLen(masked_positions, 2)
|
104 |
+
self.assertLen(masked_labels, 2)
|
105 |
+
self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
|
106 |
+
# ensure that we took an entire word.
|
107 |
+
self.assertIn(masked_labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])
|
108 |
+
|
109 |
+
def test_create_masked_lm_predictions_ngram(self):
|
110 |
+
tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(0, 512)] + ["[SEP]"]
|
111 |
+
rng = random.Random(345)
|
112 |
+
for _ in range(0, 5):
|
113 |
+
output_tokens, masked_positions, masked_labels = (
|
114 |
+
cpd.create_masked_lm_predictions(
|
115 |
+
tokens=tokens,
|
116 |
+
masked_lm_prob=1.0,
|
117 |
+
max_predictions_per_seq=76,
|
118 |
+
vocab_words=_VOCAB_WORDS,
|
119 |
+
rng=rng,
|
120 |
+
do_whole_word_mask=True,
|
121 |
+
max_ngram_size=3))
|
122 |
+
self.assertLen(masked_positions, 76)
|
123 |
+
self.assertLen(masked_labels, 76)
|
124 |
+
self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
tf.test.main()
|
create_xlnet_pretraining_data.py
ADDED
@@ -0,0 +1,721 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Create LM TF examples for XLNet."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
import json
|
19 |
+
import math
|
20 |
+
import os
|
21 |
+
|
22 |
+
import random
|
23 |
+
from typing import Iterable, Mapping, List, Optional, Tuple
|
24 |
+
import unicodedata
|
25 |
+
|
26 |
+
# Import libraries
|
27 |
+
|
28 |
+
from absl import app
|
29 |
+
from absl import flags
|
30 |
+
from absl import logging
|
31 |
+
|
32 |
+
import numpy as np
|
33 |
+
import tensorflow as tf, tf_keras
|
34 |
+
|
35 |
+
from official.nlp.tools import tokenization
|
36 |
+
|
37 |
+
special_symbols = {
|
38 |
+
"<unk>": 0,
|
39 |
+
"<s>": 1,
|
40 |
+
"</s>": 2,
|
41 |
+
"<cls>": 3,
|
42 |
+
"<sep>": 4,
|
43 |
+
"<pad>": 5,
|
44 |
+
"<mask>": 6,
|
45 |
+
"<eod>": 7,
|
46 |
+
"<eop>": 8,
|
47 |
+
}
|
48 |
+
|
49 |
+
FLAGS = flags.FLAGS
|
50 |
+
|
51 |
+
flags.DEFINE_integer("seq_length", 512,
|
52 |
+
help="Sequence length.")
|
53 |
+
flags.DEFINE_integer("reuse_length", 256,
|
54 |
+
help="Number of token that can be reused as memory. "
|
55 |
+
"Could be half of `seq_len`.")
|
56 |
+
flags.DEFINE_string("input_file", None,
|
57 |
+
"Input raw text file (or comma-separated list of files).")
|
58 |
+
flags.DEFINE_string(
|
59 |
+
"save_dir", None,
|
60 |
+
"Directory for saving processed data.")
|
61 |
+
flags.DEFINE_string("sp_model_file", "",
|
62 |
+
"The path to the model used by sentence piece tokenizer.")
|
63 |
+
flags.DEFINE_bool("use_eod_token", True,
|
64 |
+
"Whether or not to include EOD tokens.")
|
65 |
+
flags.DEFINE_bool("bi_data", True, "Whether or not to use bi-directional data.")
|
66 |
+
flags.DEFINE_bool(
|
67 |
+
"do_lower_case", True,
|
68 |
+
"Whether to lower case the input text. Should be True for uncased "
|
69 |
+
"models and False for cased models.")
|
70 |
+
flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
|
71 |
+
flags.DEFINE_integer("num_cores_per_host", 16,
|
72 |
+
"The number of (TPU) cores per host.")
|
73 |
+
flags.DEFINE_string("prefix", "", "Filename prefix.")
|
74 |
+
flags.DEFINE_string("suffix", "", "Filename suffix.")
|
75 |
+
|
76 |
+
flags.DEFINE_integer("task_id", None,
|
77 |
+
"The id of the current task.")
|
78 |
+
flags.DEFINE_integer("num_tasks", None,
|
79 |
+
"The total number of tasks.")
|
80 |
+
flags.DEFINE_integer("num_passes", 1, "The number of times to run the script.")
|
81 |
+
|
82 |
+
|
83 |
+
@dataclasses.dataclass
|
84 |
+
class TrainingInstance:
|
85 |
+
"""Representation of a single XLNet Pretraining instance."""
|
86 |
+
data: Iterable[int]
|
87 |
+
segment_ids: Iterable[int]
|
88 |
+
boundary_indices: Iterable[int]
|
89 |
+
label: int
|
90 |
+
|
91 |
+
def to_feature(self) -> Mapping[str, tf.train.Feature]:
|
92 |
+
feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
|
93 |
+
return dict(
|
94 |
+
input_word_ids=feat(self.data),
|
95 |
+
input_type_ids=feat(self.segment_ids),
|
96 |
+
boundary_indices=feat(self.boundary_indices),
|
97 |
+
label=feat([self.label]))
|
98 |
+
|
99 |
+
def to_example(self) -> tf.train.Example:
|
100 |
+
return tf.train.Example(
|
101 |
+
features=tf.train.Features(feature=self.to_feature()))
|
102 |
+
|
103 |
+
def __str__(self):
|
104 |
+
def seq_to_str(seq):
|
105 |
+
return " ".join([str(x) for x in seq])
|
106 |
+
|
107 |
+
s = ""
|
108 |
+
s += "tokens: %s\n" % seq_to_str(self.data)
|
109 |
+
s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
|
110 |
+
s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
|
111 |
+
s += "label: %s\n" % self.label
|
112 |
+
s += "\n"
|
113 |
+
return s
|
114 |
+
|
115 |
+
def __repr__(self):
|
116 |
+
return self.__str__()
|
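A small usage sketch (values invented for illustration): building one `TrainingInstance` and serializing it the way `write_instances_to_tfrecord` does further below. The ids 4 and 3 follow the `special_symbols` table for `<sep>` and `<cls>`.

instance = TrainingInstance(
    data=[10, 11, 12, 4, 13, 14, 4, 3],   # token ids ending with <sep>, <cls>
    segment_ids=[0, 0, 0, 0, 1, 1, 1, 2],
    boundary_indices=[0, 2, 5, 8],
    label=1)
serialized = instance.to_example().SerializeToString()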
117 |
+
|
118 |
+
|
119 |
+
def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
|
120 |
+
"""Preprocesses an individual raw text line.
|
121 |
+
|
122 |
+
This function will:
|
123 |
+
- Remove extraneous spaces.
|
124 |
+
- Replace `` with ", and '' with ".
|
125 |
+
- Replace accents.
|
126 |
+
- Apply lower casing.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
line: The input line to preprocess.
|
130 |
+
do_lower_case: Whether or not to lower case the text.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
The preprocessed line.
|
134 |
+
|
135 |
+
"""
|
136 |
+
line = " ".join(line.split())
|
137 |
+
line = line.replace("``", "\"").replace("''", "\"")
|
138 |
+
|
139 |
+
# Replace accents.
|
140 |
+
line = unicodedata.normalize("NFKD", line)
|
141 |
+
line = "".join([c for c in line if not unicodedata.combining(c)])
|
142 |
+
|
143 |
+
if do_lower_case:
|
144 |
+
line = line.lower()
|
145 |
+
return line
|
146 |
+
|
147 |
+
|
148 |
+
def preprocess_and_tokenize_input_files(
|
149 |
+
input_files: Iterable[str],
|
150 |
+
tokenizer: tokenization.FullSentencePieceTokenizer,
|
151 |
+
use_eod: bool = True,
|
152 |
+
do_lower_case: bool = False,
|
153 |
+
log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
|
154 |
+
"""Preprocesses and encodes raw text from input files.
|
155 |
+
|
156 |
+
This function preprocesses raw text and encodes them into tokens using a
|
157 |
+
`SentencePieceModel` tokenization method. This also provides the sentence
|
158 |
+
indicator for each token.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
input_files: The list of input file names.
|
162 |
+
tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
|
163 |
+
use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
|
164 |
+
not included.
|
165 |
+
do_lower_case: Whether or not to apply lower casing during raw text
|
166 |
+
preprocessing.
|
167 |
+
log_example_freq: The optional field for how many lines to process before
|
168 |
+
emitting an info log.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
The preprocessed list. Each entry in the list is a tuple consisting of
|
172 |
+
the token IDs and the sentence IDs.
|
173 |
+
|
174 |
+
"""
|
175 |
+
all_data = []
|
176 |
+
eod_symbol = special_symbols["<eod>"]
|
177 |
+
|
178 |
+
total_number_of_lines = 0
|
179 |
+
|
180 |
+
# Input file format:
|
181 |
+
# (1) One sentence per line. These should ideally be actual sentences, not
|
182 |
+
# entire paragraphs or arbitrary spans of text. (Because we use the
|
183 |
+
# sentence boundaries for the "next sentence prediction" task).
|
184 |
+
# (2) Blank lines between documents. Document boundaries are needed so
|
185 |
+
# that the "next sentence prediction" task doesn't span between documents.
|
186 |
+
for input_file in input_files:
|
187 |
+
line_count = 0
|
188 |
+
logging.info("Preprocessing %s", input_file)
|
189 |
+
|
190 |
+
all_tokens = []
|
191 |
+
all_sentence_ids = []
|
192 |
+
|
193 |
+
sentence_id = True
|
194 |
+
|
195 |
+
with tf.io.gfile.GFile(input_file, "rb") as reader:
|
196 |
+
while True:
|
197 |
+
line = tokenization.convert_to_unicode(reader.readline())
|
198 |
+
if not line:
|
199 |
+
break
|
200 |
+
|
201 |
+
line_count += 1
|
202 |
+
if line_count % log_example_freq == 0:
|
203 |
+
logging.info("Loading line %d", line_count)
|
204 |
+
|
205 |
+
line = line.strip()
|
206 |
+
|
207 |
+
if not line:
|
208 |
+
if use_eod:
|
209 |
+
token_ids = [eod_symbol]
|
210 |
+
sentence_id = not sentence_id
|
211 |
+
else:
|
212 |
+
continue
|
213 |
+
else:
|
214 |
+
preprocessed_line = _preprocess_line(
|
215 |
+
line=line, do_lower_case=do_lower_case)
|
216 |
+
token_ids = tokenization.encode_ids(
|
217 |
+
sp_model=tokenizer.sp_model, text=preprocessed_line)
|
218 |
+
|
219 |
+
all_tokens.extend(token_ids)
|
220 |
+
all_sentence_ids.extend([sentence_id] * len(token_ids))
|
221 |
+
sentence_id = not sentence_id
|
222 |
+
logging.info("Finished processing %s. Number of lines: %d",
|
223 |
+
input_file, line_count)
|
224 |
+
if line_count == 0:
|
225 |
+
continue
|
226 |
+
total_number_of_lines += line_count
|
227 |
+
all_tokens = np.array(all_tokens, dtype=np.int64)
|
228 |
+
all_sentence_ids = np.array(all_sentence_ids, dtype=bool)
|
229 |
+
all_data.append((all_tokens, all_sentence_ids))
|
230 |
+
|
231 |
+
logging.info("Completed text preprocessing. Total number of lines: %d",
|
232 |
+
total_number_of_lines)
|
233 |
+
return all_data
|
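A hedged example of the raw-text layout the function expects (contents are made up): one sentence per line, with a blank line separating documents.

sample_corpus = (
    "This is the first sentence of document one.\n"
    "And this is its second sentence.\n"
    "\n"
    "Document two starts after the blank line.\n"
)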
234 |
+
|
235 |
+
|
236 |
+
def _reshape_to_batch_dimensions(
|
237 |
+
tokens: np.array,
|
238 |
+
sentence_ids: np.array,
|
239 |
+
per_host_batch_size: int) -> Tuple[np.array, np.array]:
|
240 |
+
"""Truncates and reshapes input data with a batch major dimension.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
tokens: The input token ids. This should have the same shape as
|
244 |
+
`sentence_ids`.
|
245 |
+
sentence_ids: The input sentence ids. This should have the same shape as
|
246 |
+
`token_ids`.
|
247 |
+
per_host_batch_size: The target per-host batch size.
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
The tuple of reshaped tokens and sentence_ids.
|
251 |
+
"""
|
252 |
+
num_steps = len(tokens) // per_host_batch_size
|
253 |
+
truncated_data_length = num_steps * per_host_batch_size
|
254 |
+
|
255 |
+
logging.info("per_host_batch_size: %d", per_host_batch_size)
|
256 |
+
logging.info("num_steps: %d", num_steps)
|
257 |
+
def truncate_and_reshape(a):
|
258 |
+
return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))
|
259 |
+
|
260 |
+
return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))
|
261 |
+
|
262 |
+
|
263 |
+
def _create_a_and_b_segments(
|
264 |
+
tokens: np.array,
|
265 |
+
sentence_ids: np.array,
|
266 |
+
begin_index: int,
|
267 |
+
total_length: int,
|
268 |
+
no_cut_probability: float = 0.5):
|
269 |
+
"""Splits segments A and B from a single instance of tokens and sentence ids.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
tokens: The 1D input token ids. This represents an individual entry within a
|
273 |
+
batch.
|
274 |
+
sentence_ids: The 1D input sentence ids. This represents an individual entry
|
275 |
+
within a batch. This should be the same length as `tokens`.
|
276 |
+
begin_index: The reference beginning index to split data.
|
277 |
+
total_length: The target combined length of segments A and B.
|
278 |
+
no_cut_probability: The probability of not cutting a segment despite
|
279 |
+
a cut possibly existing.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
A tuple consisting of A data, B data, and label.
|
283 |
+
|
284 |
+
"""
|
285 |
+
data_length = tokens.shape[0]
|
286 |
+
if begin_index + total_length >= data_length:
|
287 |
+
logging.info("[_create_segments]: begin_index %d + total_length %d >= "
|
288 |
+
"data_length %d", begin_index, total_length, data_length)
|
289 |
+
return None
|
290 |
+
|
291 |
+
end_index = begin_index + 1
|
292 |
+
cut_indices = []
|
293 |
+
|
294 |
+
# Identify all indices where sentence IDs change from one to the next.
|
295 |
+
while end_index < data_length:
|
296 |
+
if sentence_ids[end_index] != sentence_ids[end_index - 1]:
|
297 |
+
if end_index - begin_index >= total_length:
|
298 |
+
break
|
299 |
+
cut_indices.append(end_index)
|
300 |
+
end_index += 1
|
301 |
+
|
302 |
+
a_begin = begin_index
|
303 |
+
|
304 |
+
if not cut_indices or random.random() < no_cut_probability:
|
305 |
+
# Segments A and B are contained within the same sentence.
|
306 |
+
label = 0
|
307 |
+
if not cut_indices:
|
308 |
+
a_end = end_index
|
309 |
+
else:
|
310 |
+
a_end = random.choice(cut_indices)
|
311 |
+
b_length = max(1, total_length - (a_end - a_begin))
|
312 |
+
b_begin = random.randint(0, data_length - 1 - b_length)
|
313 |
+
b_end = b_begin + b_length
|
314 |
+
|
315 |
+
while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
|
316 |
+
b_begin -= 1
|
317 |
+
while (b_end < data_length - 1 and
|
318 |
+
sentence_ids[b_end - 1] == sentence_ids[b_end]):
|
319 |
+
b_end += 1
|
320 |
+
else:
|
321 |
+
# Segments A and B are different sentences.
|
322 |
+
label = 1
|
323 |
+
a_end = random.choice(cut_indices)
|
324 |
+
b_begin = a_end
|
325 |
+
b_end = end_index
|
326 |
+
|
327 |
+
while a_end - a_begin + b_end - b_begin > total_length:
|
328 |
+
if a_end - a_begin > b_end - b_begin:
|
329 |
+
# Delete only the right side for the LM objective.
|
330 |
+
a_end -= 1
|
331 |
+
else:
|
332 |
+
b_end -= 1
|
333 |
+
if a_end >= data_length or b_end >= data_length:
|
334 |
+
logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
|
335 |
+
a_end, b_end, data_length)
|
336 |
+
return None
|
337 |
+
|
338 |
+
a_data = tokens[a_begin: a_end]
|
339 |
+
b_data = tokens[b_begin: b_end]
|
340 |
+
return a_data, b_data, label
|
341 |
+
|
342 |
+
|
343 |
+
def _is_functional_piece(piece: str) -> bool:
|
344 |
+
return piece != "<unk>" and piece.startswith("<") and piece.endswith(">")
|
345 |
+
|
346 |
+
|
347 |
+
def _is_start_piece(piece: str) -> bool:
|
348 |
+
special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
|
349 |
+
if (piece.startswith("β") or piece in special_pieces):
|
350 |
+
return True
|
351 |
+
else:
|
352 |
+
return False
|
353 |
+
|
354 |
+
|
355 |
+
def _get_boundary_indices(
|
356 |
+
data: np.array,
|
357 |
+
tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
|
358 |
+
"""Gets the boundary indices of whole words."""
|
359 |
+
seq_length = len(data)
|
360 |
+
boundary_indices = []
|
361 |
+
for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
|
362 |
+
if _is_start_piece(piece) and not _is_functional_piece(piece):
|
363 |
+
boundary_indices.append(index)
|
364 |
+
boundary_indices.append(seq_length)
|
365 |
+
return boundary_indices
|
366 |
+
|
367 |
+
|
368 |
+
def _convert_tokens_to_instances(
|
369 |
+
tokens: np.array,
|
370 |
+
sentence_ids: np.array,
|
371 |
+
per_host_batch_size: int,
|
372 |
+
seq_length: int,
|
373 |
+
reuse_length: int,
|
374 |
+
bi_data: bool,
|
375 |
+
tokenizer: tokenization.FullSentencePieceTokenizer,
|
376 |
+
num_cores_per_host: int = 0,
|
377 |
+
logging_frequency: int = 500) -> List[TrainingInstance]:
|
378 |
+
"""Converts tokens and sentence IDs into individual training instances.
|
379 |
+
|
380 |
+
The format of data in the XLNet pretraining task is very similar to the
|
381 |
+
BERT pretraining task. Two segments A and B are randomly sampled, and the
|
382 |
+
concatenation of A and B into a single sequence is used to perform
|
383 |
+
language modeling.
|
384 |
+
|
385 |
+
To create an XLNet Pretraining instance from a single long sequence, S:
|
386 |
+
- Create a segment of length `reuse_length`. This first segment represents
|
387 |
+
past tokens. During modeling, this segment is used to cache obtained
|
388 |
+
content representations for the segment recurrence mechanism.
|
389 |
+
- Similar to BERT, create a segment of length `seq_length` - `reuse_length`
|
390 |
+
composed of A and B segments.
|
391 |
+
For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".
|
392 |
+
|
393 |
+
Args:
|
394 |
+
tokens: All tokens concatenated into a single list.
|
395 |
+
sentence_ids: All sentence IDs concatenated into a single list.
|
396 |
+
per_host_batch_size: The target batch size per host.
|
397 |
+
seq_length: The max sequence length.
|
398 |
+
reuse_length: The number of tokens to use from the previous segment.
|
399 |
+
bi_data: Whether or not to use bidirectional data.
|
400 |
+
tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
|
401 |
+
num_cores_per_host: The number of cores per host. This is required if
|
402 |
+
`bi_data` = `True`.
|
403 |
+
logging_frequency: The frequency at which to log status updates.
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
A list of `TrainingInstance` objects.
|
407 |
+
"""
|
408 |
+
instances = []
|
409 |
+
|
410 |
+
per_core_batch_size = (per_host_batch_size // num_cores_per_host
|
411 |
+
if bi_data else None)
|
412 |
+
|
413 |
+
if bi_data:
|
414 |
+
logging.info("Bi-directional data enabled.")
|
415 |
+
assert per_host_batch_size % (2 * num_cores_per_host) == 0
|
416 |
+
forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
|
417 |
+
tokens=tokens,
|
418 |
+
sentence_ids=sentence_ids,
|
419 |
+
per_host_batch_size=per_host_batch_size // 2)
|
420 |
+
forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)
|
421 |
+
|
422 |
+
forward_tokens = forward_tokens.reshape(forward_data_shape)
|
423 |
+
forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)
|
424 |
+
|
425 |
+
backwards_tokens = forward_tokens[:, :, :, ::-1]
|
426 |
+
backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]
|
427 |
+
|
428 |
+
tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
|
429 |
+
per_host_batch_size, -1)
|
430 |
+
sentence_ids = np.concatenate(
|
431 |
+
[forward_sentence_ids, backwards_sentence_ids]).reshape(
|
432 |
+
per_host_batch_size, -1)
|
433 |
+
else:
|
434 |
+
logging.info("Bi-directional data disabled.")
|
435 |
+
tokens, sentence_ids = _reshape_to_batch_dimensions(
|
436 |
+
tokens=tokens,
|
437 |
+
sentence_ids=sentence_ids,
|
438 |
+
per_host_batch_size=per_host_batch_size)
|
439 |
+
|
440 |
+
logging.info("Tokens shape: %s", tokens.shape)
|
441 |
+
|
442 |
+
data_length = tokens.shape[1]
|
443 |
+
sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
|
444 |
+
cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
|
445 |
+
# 2 sep, 1 cls
|
446 |
+
num_special_tokens = 3
|
447 |
+
|
448 |
+
data_index = 0
|
449 |
+
batch_number = 0
|
450 |
+
step_size = reuse_length if reuse_length else seq_length
|
451 |
+
num_batches = math.ceil(data_length / step_size)
|
452 |
+
|
453 |
+
while data_index + seq_length <= data_length:
|
454 |
+
if batch_number % logging_frequency == 0:
|
455 |
+
logging.info("Processing batch %d of %d", batch_number, num_batches)
|
456 |
+
|
457 |
+
for batch_index in range(per_host_batch_size):
|
458 |
+
previous_segment_tokens = tokens[
|
459 |
+
batch_index, data_index: data_index + reuse_length]
|
460 |
+
|
461 |
+
results = _create_a_and_b_segments(
|
462 |
+
tokens=tokens[batch_index],
|
463 |
+
sentence_ids=sentence_ids[batch_index],
|
464 |
+
begin_index=data_index + reuse_length,
|
465 |
+
total_length=seq_length - reuse_length - num_special_tokens)
|
466 |
+
|
467 |
+
if results is None:
|
468 |
+
logging.info("Stopping at data index: %d", data_index)
|
469 |
+
break
|
470 |
+
a_data, b_data, label = results
|
471 |
+
|
472 |
+
data = np.concatenate(
|
473 |
+
[previous_segment_tokens, a_data, sep, b_data, sep, cls])
|
474 |
+
a_length = a_data.shape[0]
|
475 |
+
b_length = b_data.shape[0]
|
476 |
+
segment_ids = ([0] * (reuse_length + a_length) + [0]
|
477 |
+
+ [1] * b_length + [1] + [2])
|
478 |
+
boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
|
479 |
+
data=data)
|
480 |
+
assert len(data) == seq_length
|
481 |
+
assert len(segment_ids) == seq_length
|
482 |
+
assert len(boundary_indices) > 0 # pylint: disable=g-explicit-length-test
|
483 |
+
|
484 |
+
instances.append(TrainingInstance(
|
485 |
+
data=data,
|
486 |
+
segment_ids=segment_ids,
|
487 |
+
boundary_indices=boundary_indices,
|
488 |
+
label=label))
|
489 |
+
batch_number += 1
|
490 |
+
data_index += step_size
|
491 |
+
return instances
|
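Illustration of the packed-sequence layout described in the docstring above, for invented lengths (reuse_length=8, a_length=3, b_length=2, so seq_length=16 once the three special positions are counted):

reuse_length, a_length, b_length = 8, 3, 2
# [ reused tokens | A | <sep> | B | <sep> | <cls> ]
segment_ids = ([0] * (reuse_length + a_length) + [0]
               + [1] * b_length + [1] + [2])
assert len(segment_ids) == reuse_length + a_length + b_length + 3  # == 16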
492 |
+
|
493 |
+
|
494 |
+
def write_instances_to_tfrecord(
|
495 |
+
instances: Iterable[TrainingInstance],
|
496 |
+
save_path: str):
|
497 |
+
"""Writes instances to TFRecord."""
|
498 |
+
record_writer = tf.io.TFRecordWriter(save_path)
|
499 |
+
logging.info("Start writing to %s.", save_path)
|
500 |
+
|
501 |
+
for i, instance in enumerate(instances):
|
502 |
+
if i < 5:
|
503 |
+
logging.info("Instance %d: %s", i, str(instance))
|
504 |
+
record_writer.write(instance.to_example().SerializeToString())
|
505 |
+
|
506 |
+
record_writer.close()
|
507 |
+
logging.info("Done writing %s.", save_path)
|
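A hedged sketch of reading the written records back; the feature names come from `TrainingInstance.to_feature` above, while the path and sequence length are placeholders.

features = {
    "input_word_ids": tf.io.FixedLenFeature([512], tf.int64),
    "input_type_ids": tf.io.FixedLenFeature([512], tf.int64),
    "boundary_indices": tf.io.VarLenFeature(tf.int64),
    "label": tf.io.FixedLenFeature([1], tf.int64),
}
dataset = tf.data.TFRecordDataset("/tmp/xlnet_pretrain.tfrecord").map(
    lambda record: tf.io.parse_single_example(record, features))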
508 |
+
|
509 |
+
|
510 |
+
def shuffle_and_combine_preprocessed_data(
|
511 |
+
all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
|
512 |
+
"""Shuffles and combines preprocessed token/sentence IDs from documents."""
|
513 |
+
document_permutation = np.random.permutation(len(all_data))
|
514 |
+
|
515 |
+
previous_sentence_id = None
|
516 |
+
|
517 |
+
all_tokens, all_sentence_ids = [], []
|
518 |
+
for document_index in document_permutation:
|
519 |
+
tokens, sentence_ids = all_data[document_index]
|
520 |
+
# pylint: disable=g-explicit-length-test
|
521 |
+
if len(tokens) == 0:
|
522 |
+
continue
|
523 |
+
if (previous_sentence_id is not None and
|
524 |
+
sentence_ids[0] == previous_sentence_id):
|
525 |
+
sentence_ids = np.logical_not(sentence_ids)
|
526 |
+
|
527 |
+
all_tokens.append(tokens)
|
528 |
+
all_sentence_ids.append(sentence_ids)
|
529 |
+
|
530 |
+
previous_sentence_id = sentence_ids[-1]
|
531 |
+
|
532 |
+
return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)
|
533 |
+
|
534 |
+
|
535 |
+
def get_tfrecord_name(
|
536 |
+
per_host_batch_size: int,
|
537 |
+
num_cores_per_host: int,
|
538 |
+
seq_length: int,
|
539 |
+
bi_data: bool,
|
540 |
+
reuse_length: int,
|
541 |
+
do_lower_case: bool,
|
542 |
+
use_eod_token: bool,
|
543 |
+
prefix: str = "",
|
544 |
+
suffix: str = "",
|
545 |
+
pass_id: int = 0,
|
546 |
+
num_passes: int = 1,
|
547 |
+
task_id: int = None,
|
548 |
+
num_tasks: int = None) -> str:
|
549 |
+
"""Formats the resulting TFRecord name based on provided inputs."""
|
550 |
+
components = []
|
551 |
+
if prefix:
|
552 |
+
components.append(prefix)
|
553 |
+
components.append("seqlen-{}".format(seq_length))
|
554 |
+
if reuse_length == 0:
|
555 |
+
components.append("memless")
|
556 |
+
else:
|
557 |
+
components.append("reuse-{}".format(reuse_length))
|
558 |
+
components.append("bs-{}".format(per_host_batch_size))
|
559 |
+
components.append("cores-{}".format(num_cores_per_host))
|
560 |
+
|
561 |
+
if do_lower_case:
|
562 |
+
components.append("uncased")
|
563 |
+
else:
|
564 |
+
components.append("cased")
|
565 |
+
if use_eod_token:
|
566 |
+
components.append("eod")
|
567 |
+
if bi_data:
|
568 |
+
components.append("bi")
|
569 |
+
else:
|
570 |
+
components.append("uni")
|
571 |
+
|
572 |
+
if suffix:
|
573 |
+
components.append(suffix)
|
574 |
+
|
575 |
+
s = "_".join(components) + ".tfrecord"
|
576 |
+
if num_passes == 1 and task_id is None:
|
577 |
+
return s
|
578 |
+
|
579 |
+
if task_id is None:
|
580 |
+
num_tasks = 1
|
581 |
+
task_id = 0
|
582 |
+
|
583 |
+
current_shard = task_id * num_passes + pass_id
|
584 |
+
total_shards = num_tasks * num_passes
|
585 |
+
return s + "-{}-of-{}".format(current_shard, total_shards)
|
586 |
+
|
587 |
+
|
588 |
+
def create_tfrecords(
|
589 |
+
tokenizer: tokenization.FullSentencePieceTokenizer,
|
590 |
+
input_file_or_files: str,
|
591 |
+
use_eod_token: bool,
|
592 |
+
do_lower_case: bool,
|
593 |
+
per_host_batch_size: int,
|
594 |
+
seq_length: int,
|
595 |
+
reuse_length: int,
|
596 |
+
bi_data: bool,
|
597 |
+
num_cores_per_host: int,
|
598 |
+
save_dir: str,
|
599 |
+
prefix: str = "",
|
600 |
+
suffix: str = "",
|
601 |
+
num_tasks: Optional[int] = None,
|
602 |
+
task_id: Optional[int] = None,
|
603 |
+
num_passes: int = 1):
|
604 |
+
"""Runs the end-to-end preprocessing pipeline."""
|
605 |
+
|
606 |
+
logging.info("Input configuration:")
|
607 |
+
logging.info("input file(s): %s", input_file_or_files)
|
608 |
+
logging.info("use_eod_token: %s", use_eod_token)
|
609 |
+
logging.info("do_lower_case: %s", do_lower_case)
|
610 |
+
logging.info("per_host_batch_size: %d", per_host_batch_size)
|
611 |
+
logging.info("seq_length: %d", seq_length)
|
612 |
+
logging.info("reuse_length: %d", reuse_length)
|
613 |
+
logging.info("bi_data: %s", bi_data)
|
614 |
+
logging.info("num_cores_per_host: %d", num_cores_per_host)
|
615 |
+
logging.info("save_dir: %s", save_dir)
|
616 |
+
if task_id is not None and num_tasks is not None:
|
617 |
+
logging.info("task_id: %d", task_id)
|
618 |
+
logging.info("num_tasks: %d", num_tasks)
|
619 |
+
|
620 |
+
input_files = []
|
621 |
+
for input_pattern in input_file_or_files.split(","):
|
622 |
+
input_files.extend(tf.io.gfile.glob(input_pattern))
|
623 |
+
|
624 |
+
logging.info("*** Reading from input files ***")
|
625 |
+
for input_file in input_files:
|
626 |
+
logging.info(" %s", input_file)
|
627 |
+
|
628 |
+
logging.info("Shuffling the files with a fixed random seed.")
|
629 |
+
np.random.shuffle(input_files)
|
630 |
+
if num_tasks is not None:
|
631 |
+
assert task_id is not None
|
632 |
+
logging.info("Total number of input files: %d", len(input_files))
|
633 |
+
logging.info("Splitting into %d shards of %d files each.",
|
634 |
+
num_tasks, len(input_files) // num_tasks)
|
635 |
+
input_files = input_files[task_id::num_tasks]
|
636 |
+
|
637 |
+
all_data = preprocess_and_tokenize_input_files(
|
638 |
+
input_files=input_files,
|
639 |
+
tokenizer=tokenizer,
|
640 |
+
use_eod=use_eod_token,
|
641 |
+
do_lower_case=do_lower_case)
|
642 |
+
for pass_id in range(num_passes):
|
643 |
+
logging.info("Beginning pass %d of %d", pass_id, num_passes)
|
644 |
+
tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)
|
645 |
+
|
646 |
+
assert len(tokens) == len(sentence_ids)
|
647 |
+
|
648 |
+
filename = get_tfrecord_name(
|
649 |
+
per_host_batch_size=per_host_batch_size,
|
650 |
+
num_cores_per_host=num_cores_per_host,
|
651 |
+
seq_length=seq_length,
|
652 |
+
bi_data=bi_data,
|
653 |
+
use_eod_token=use_eod_token,
|
654 |
+
reuse_length=reuse_length,
|
655 |
+
do_lower_case=do_lower_case,
|
656 |
+
prefix=prefix,
|
657 |
+
suffix=suffix,
|
658 |
+
pass_id=pass_id,
|
659 |
+
num_passes=num_passes,
|
660 |
+
num_tasks=num_tasks,
|
661 |
+
task_id=task_id)
|
662 |
+
save_path = os.path.join(save_dir, filename)
|
663 |
+
if os.path.exists(save_path):
|
664 |
+
# If the path already exists, then we were probably preempted but
|
665 |
+
# previously wrote this file.
|
666 |
+
logging.info("%s already exists, skipping this batch.", save_path)
|
667 |
+
else:
|
668 |
+
instances = _convert_tokens_to_instances(
|
669 |
+
tokenizer=tokenizer,
|
670 |
+
tokens=tokens,
|
671 |
+
sentence_ids=sentence_ids,
|
672 |
+
per_host_batch_size=per_host_batch_size,
|
673 |
+
seq_length=seq_length,
|
674 |
+
reuse_length=reuse_length,
|
675 |
+
bi_data=bi_data,
|
676 |
+
num_cores_per_host=num_cores_per_host)
|
677 |
+
write_instances_to_tfrecord(instances=instances, save_path=save_path)
|
678 |
+
|
679 |
+
if task_id is None or task_id == 0:
|
680 |
+
corpus_info = {
|
681 |
+
"vocab_size": 32000,
|
682 |
+
"per_host_batch_size": per_host_batch_size,
|
683 |
+
"num_cores_per_host": num_cores_per_host,
|
684 |
+
"seq_length": seq_length,
|
685 |
+
"reuse_length": reuse_length,
|
686 |
+
"do_lower_case": do_lower_case,
|
687 |
+
"bi_data": bi_data,
|
688 |
+
"use_eod_token": use_eod_token,
|
689 |
+
}
|
690 |
+
corpus_fname = os.path.basename(filename) + ".json"
|
691 |
+
corpus_destination = os.path.join(save_dir, corpus_fname)
|
692 |
+
logging.info("Saving corpus info to %s", corpus_destination)
|
693 |
+
|
694 |
+
with tf.io.gfile.GFile(corpus_destination, "w") as fp:
|
695 |
+
json.dump(corpus_info, fp)
|
696 |
+
|
697 |
+
|
698 |
+
def main(_):
|
699 |
+
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
|
700 |
+
create_tfrecords(
|
701 |
+
tokenizer=tokenizer,
|
702 |
+
input_file_or_files=FLAGS.input_file,
|
703 |
+
use_eod_token=FLAGS.use_eod_token,
|
704 |
+
do_lower_case=FLAGS.do_lower_case,
|
705 |
+
per_host_batch_size=FLAGS.per_host_batch_size,
|
706 |
+
seq_length=FLAGS.seq_length,
|
707 |
+
reuse_length=FLAGS.reuse_length,
|
708 |
+
bi_data=FLAGS.bi_data,
|
709 |
+
num_cores_per_host=FLAGS.num_cores_per_host,
|
710 |
+
save_dir=FLAGS.save_dir,
|
711 |
+
prefix=FLAGS.prefix,
|
712 |
+
suffix=FLAGS.suffix,
|
713 |
+
num_tasks=FLAGS.num_tasks,
|
714 |
+
task_id=FLAGS.task_id,
|
715 |
+
num_passes=FLAGS.num_passes)
|
716 |
+
|
717 |
+
|
718 |
+
if __name__ == "__main__":
|
719 |
+
np.random.seed(0)
|
720 |
+
logging.set_verbosity(logging.INFO)
|
721 |
+
app.run(main)
|
create_xlnet_pretraining_data_test.py
ADDED
@@ -0,0 +1,355 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for official.nlp.data.create_xlnet_pretraining_data."""
|
16 |
+
import os
|
17 |
+
import tempfile
|
18 |
+
from typing import List
|
19 |
+
|
20 |
+
from absl import logging
|
21 |
+
from absl.testing import parameterized
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import tensorflow as tf, tf_keras
|
25 |
+
|
26 |
+
from official.nlp.data import create_xlnet_pretraining_data as cpd
|
27 |
+
|
28 |
+
_VOCAB_WORDS = ["vocab_1", "vocab_2"]
|
29 |
+
|
30 |
+
|
31 |
+
# pylint: disable=invalid-name
|
32 |
+
def _create_files(
|
33 |
+
temp_dir: str, file_contents: List[List[str]]) -> List[str]:
|
34 |
+
"""Writes arbitrary documents into files."""
|
35 |
+
root_dir = tempfile.mkdtemp(dir=temp_dir)
|
36 |
+
files = []
|
37 |
+
|
38 |
+
for i, file_content in enumerate(file_contents):
|
39 |
+
destination = os.path.join(root_dir, "%d.txt" % i)
|
40 |
+
with open(destination, "wb") as f:
|
41 |
+
for line in file_content:
|
42 |
+
f.write(line.encode("utf-8"))
|
43 |
+
files.append(destination)
|
44 |
+
return files
|
45 |
+
|
46 |
+
|
47 |
+
def _get_mock_tokenizer():
|
48 |
+
"""Creates a mock tokenizer."""
|
49 |
+
|
50 |
+
class MockSpieceModel:
|
51 |
+
"""Mock Spiece model for testing."""
|
52 |
+
|
53 |
+
def __init__(self):
|
54 |
+
self._special_piece_to_id = {
|
55 |
+
"<unk>": 0,
|
56 |
+
}
|
57 |
+
for piece in set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')):
|
58 |
+
self._special_piece_to_id[piece] = 1
|
59 |
+
|
60 |
+
def EncodeAsPieces(self, inputs: str) -> List[str]:
|
61 |
+
return inputs
|
62 |
+
|
63 |
+
def SampleEncodeAsPieces(self,
|
64 |
+
inputs: str,
|
65 |
+
nbest_size: int,
|
66 |
+
theta: float) -> List[str]:
|
67 |
+
del nbest_size, theta
|
68 |
+
return inputs
|
69 |
+
|
70 |
+
def PieceToId(self, piece: str) -> int:
|
71 |
+
return ord(piece[0])
|
72 |
+
|
73 |
+
def IdToPiece(self, id_: int) -> str:
|
74 |
+
return chr(id_) * 3
|
75 |
+
|
76 |
+
class Tokenizer:
|
77 |
+
"""Mock Tokenizer for testing."""
|
78 |
+
|
79 |
+
def __init__(self):
|
80 |
+
self.sp_model = MockSpieceModel()
|
81 |
+
|
82 |
+
def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
|
83 |
+
return [self.sp_model.IdToPiece(id_) for id_ in ids]
|
84 |
+
|
85 |
+
return Tokenizer()
|
86 |
+
|
87 |
+
|
88 |
+
class PreprocessDataTest(tf.test.TestCase):
|
89 |
+
|
90 |
+
def test_remove_extraneous_space(self):
|
91 |
+
line = " abc "
|
92 |
+
output = cpd._preprocess_line(line)
|
93 |
+
self.assertEqual(output, "abc")
|
94 |
+
|
95 |
+
def test_symbol_replacements(self):
|
96 |
+
self.assertEqual(cpd._preprocess_line("``abc``"), "\"abc\"")
|
97 |
+
self.assertEqual(cpd._preprocess_line("''abc''"), "\"abc\"")
|
98 |
+
|
99 |
+
def test_accent_replacements(self):
|
100 |
+
self.assertEqual(cpd._preprocess_line("Γ₯bc"), "abc")
|
101 |
+
|
102 |
+
def test_lower_case(self):
|
103 |
+
self.assertEqual(cpd._preprocess_line("ABC", do_lower_case=True), "abc")
|
104 |
+
|
105 |
+
def test_end_to_end(self):
|
106 |
+
self.assertEqual(
|
107 |
+
cpd._preprocess_line("HelLo ``wΓ³rLd``", do_lower_case=True),
|
108 |
+
"hello \"world\"")
|
109 |
+
|
110 |
+
|
111 |
+
class PreprocessAndTokenizeFilesTest(tf.test.TestCase):
|
112 |
+
|
113 |
+
def test_basic_end_to_end(self):
|
114 |
+
documents = [
|
115 |
+
[
|
116 |
+
"This is sentence 1.\n",
|
117 |
+
"This is sentence 2.\n",
|
118 |
+
"Sentence 3 is what this is.\n",
|
119 |
+
],
|
120 |
+
[
|
121 |
+
"This is the second document.\n",
|
122 |
+
"This is the second line of the second document.\n"
|
123 |
+
],
|
124 |
+
]
|
125 |
+
input_files = _create_files(temp_dir=self.get_temp_dir(),
|
126 |
+
file_contents=documents)
|
127 |
+
all_data = cpd.preprocess_and_tokenize_input_files(
|
128 |
+
input_files=input_files,
|
129 |
+
tokenizer=_get_mock_tokenizer(),
|
130 |
+
log_example_freq=1)
|
131 |
+
|
132 |
+
self.assertEqual(len(all_data), len(documents))
|
133 |
+
for token_ids, sentence_ids in all_data:
|
134 |
+
self.assertEqual(len(token_ids), len(sentence_ids))
|
135 |
+
|
136 |
+
def test_basic_correctness(self):
|
137 |
+
documents = [["a\n", "b\n", "c\n"]]
|
138 |
+
input_files = _create_files(temp_dir=self.get_temp_dir(),
|
139 |
+
file_contents=documents)
|
140 |
+
all_data = cpd.preprocess_and_tokenize_input_files(
|
141 |
+
input_files=input_files,
|
142 |
+
tokenizer=_get_mock_tokenizer(),
|
143 |
+
log_example_freq=1)
|
144 |
+
|
145 |
+
token_ids, sentence_ids = all_data[0]
|
146 |
+
|
147 |
+
self.assertAllClose(token_ids, [97, 98, 99])
|
148 |
+
self.assertAllClose(sentence_ids, [True, False, True])
|
149 |
+
|
150 |
+
def test_correctness_with_spaces_and_accents(self):
|
151 |
+
documents = [[
|
152 |
+
" Γ₯ \n",
|
153 |
+
"b \n",
|
154 |
+
" c \n",
|
155 |
+
]]
|
156 |
+
input_files = _create_files(temp_dir=self.get_temp_dir(),
|
157 |
+
file_contents=documents)
|
158 |
+
all_data = cpd.preprocess_and_tokenize_input_files(
|
159 |
+
input_files=input_files,
|
160 |
+
tokenizer=_get_mock_tokenizer(),
|
161 |
+
log_example_freq=1)
|
162 |
+
|
163 |
+
token_ids, sentence_ids = all_data[0]
|
164 |
+
|
165 |
+
self.assertAllClose(token_ids, [97, 98, 99])
|
166 |
+
self.assertAllClose(sentence_ids, [True, False, True])
|
167 |
+
|
168 |
+
|
169 |
+
class BatchReshapeTests(tf.test.TestCase):
|
170 |
+
|
171 |
+
def test_basic_functionality(self):
|
172 |
+
per_host_batch_size = 3
|
173 |
+
mock_shape = (20,)
|
174 |
+
|
175 |
+
# Should truncate and reshape.
|
176 |
+
expected_result_shape = (3, 6)
|
177 |
+
|
178 |
+
tokens = np.zeros(mock_shape)
|
179 |
+
sentence_ids = np.zeros(mock_shape)
|
180 |
+
|
181 |
+
reshaped_data = cpd._reshape_to_batch_dimensions(
|
182 |
+
tokens=tokens,
|
183 |
+
sentence_ids=sentence_ids,
|
184 |
+
per_host_batch_size=per_host_batch_size)
|
185 |
+
for values in reshaped_data:
|
186 |
+
self.assertEqual(len(values.flatten()) % per_host_batch_size, 0)
|
187 |
+
self.assertAllClose(values.shape, expected_result_shape)
|
188 |
+
|
189 |
+
|
190 |
+
class CreateSegmentsTest(tf.test.TestCase):
|
191 |
+
|
192 |
+
def test_basic_functionality(self):
|
193 |
+
data_length = 10
|
194 |
+
tokens = np.arange(data_length)
|
195 |
+
sentence_ids = np.concatenate([np.zeros(data_length // 2),
|
196 |
+
np.ones(data_length // 2)])
|
197 |
+
begin_index = 0
|
198 |
+
total_length = 8
|
199 |
+
a_data, b_data, label = cpd._create_a_and_b_segments(
|
200 |
+
tokens=tokens,
|
201 |
+
sentence_ids=sentence_ids,
|
202 |
+
begin_index=begin_index,
|
203 |
+
total_length=total_length,
|
204 |
+
no_cut_probability=0.)
|
205 |
+
self.assertAllClose(a_data, [0, 1, 2, 3])
|
206 |
+
self.assertAllClose(b_data, [5, 6, 7, 8])
|
207 |
+
self.assertEqual(label, 1)
|
208 |
+
|
209 |
+
def test_no_cut(self):
|
210 |
+
data_length = 10
|
211 |
+
tokens = np.arange(data_length)
|
212 |
+
sentence_ids = np.zeros(data_length)
|
213 |
+
|
214 |
+
begin_index = 0
|
215 |
+
total_length = 8
|
216 |
+
a_data, b_data, label = cpd._create_a_and_b_segments(
|
217 |
+
tokens=tokens,
|
218 |
+
sentence_ids=sentence_ids,
|
219 |
+
begin_index=begin_index,
|
220 |
+
total_length=total_length,
|
221 |
+
no_cut_probability=0.)
|
222 |
+
self.assertGreater(len(a_data), 0)
|
223 |
+
self.assertGreater(len(b_data), 0)
|
224 |
+
self.assertEqual(label, 0)
|
225 |
+
|
226 |
+
def test_no_cut_with_probability(self):
|
227 |
+
data_length = 10
|
228 |
+
tokens = np.arange(data_length)
|
229 |
+
sentence_ids = np.concatenate([np.zeros(data_length // 2),
|
230 |
+
np.ones(data_length // 2)])
|
231 |
+
begin_index = 0
|
232 |
+
total_length = 8
|
233 |
+
a_data, b_data, label = cpd._create_a_and_b_segments(
|
234 |
+
tokens=tokens,
|
235 |
+
sentence_ids=sentence_ids,
|
236 |
+
begin_index=begin_index,
|
237 |
+
total_length=total_length,
|
238 |
+
no_cut_probability=1.)
|
239 |
+
self.assertGreater(len(a_data), 0)
|
240 |
+
self.assertGreater(len(b_data), 0)
|
241 |
+
self.assertEqual(label, 0)
|
242 |
+
|
243 |
+
|
244 |
+
class CreateInstancesTest(tf.test.TestCase):
|
245 |
+
"""Tests conversions of Token/Sentence IDs to training instances."""
|
246 |
+
|
247 |
+
def test_basic(self):
|
248 |
+
data_length = 12
|
249 |
+
tokens = np.arange(data_length)
|
250 |
+
sentence_ids = np.zeros(data_length)
|
251 |
+
seq_length = 8
|
252 |
+
instances = cpd._convert_tokens_to_instances(
|
253 |
+
tokens=tokens,
|
254 |
+
sentence_ids=sentence_ids,
|
255 |
+
per_host_batch_size=2,
|
256 |
+
seq_length=seq_length,
|
257 |
+
reuse_length=4,
|
258 |
+
tokenizer=_get_mock_tokenizer(),
|
259 |
+
bi_data=False,
|
260 |
+
num_cores_per_host=1,
|
261 |
+
logging_frequency=1)
|
262 |
+
for instance in instances:
|
263 |
+
self.assertEqual(len(instance.data), seq_length)
|
264 |
+
self.assertEqual(len(instance.segment_ids), seq_length)
|
265 |
+
self.assertIsInstance(instance.label, int)
|
266 |
+
self.assertIsInstance(instance.boundary_indices, list)
|
267 |
+
|
268 |
+
|
269 |
+
class TFRecordPathTests(tf.test.TestCase):
|
270 |
+
|
271 |
+
def test_basic(self):
|
272 |
+
base_kwargs = dict(
|
273 |
+
per_host_batch_size=1,
|
274 |
+
num_cores_per_host=1,
|
275 |
+
seq_length=2,
|
276 |
+
reuse_length=1)
|
277 |
+
|
278 |
+
config1 = dict(
|
279 |
+
prefix="test",
|
280 |
+
suffix="",
|
281 |
+
bi_data=True,
|
282 |
+
use_eod_token=False,
|
283 |
+
do_lower_case=True)
|
284 |
+
config1.update(base_kwargs)
|
285 |
+
expectation1 = "test_seqlen-2_reuse-1_bs-1_cores-1_uncased_bi.tfrecord"
|
286 |
+
self.assertEqual(cpd.get_tfrecord_name(**config1), expectation1)
|
287 |
+
|
288 |
+
config2 = dict(
|
289 |
+
prefix="",
|
290 |
+
suffix="test",
|
291 |
+
bi_data=False,
|
292 |
+
use_eod_token=False,
|
293 |
+
do_lower_case=False)
|
294 |
+
config2.update(base_kwargs)
|
295 |
+
expectation2 = "seqlen-2_reuse-1_bs-1_cores-1_cased_uni_test.tfrecord"
|
296 |
+
self.assertEqual(cpd.get_tfrecord_name(**config2), expectation2)
|
297 |
+
|
298 |
+
config3 = dict(
|
299 |
+
prefix="",
|
300 |
+
suffix="",
|
301 |
+
use_eod_token=True,
|
302 |
+
bi_data=False,
|
303 |
+
do_lower_case=True)
|
304 |
+
config3.update(base_kwargs)
|
305 |
+
expectation3 = "seqlen-2_reuse-1_bs-1_cores-1_uncased_eod_uni.tfrecord"
|
306 |
+
self.assertEqual(cpd.get_tfrecord_name(**config3), expectation3)
|
307 |
+
|
308 |
+
|
309 |
+
class TestCreateTFRecords(parameterized.TestCase, tf.test.TestCase):
|
310 |
+
|
311 |
+
@parameterized.named_parameters(
|
312 |
+
("bi_data_only", True, False, False),
|
313 |
+
("eod_token_only", False, True, True),
|
314 |
+
("lower_case_only", False, False, True),
|
315 |
+
("all_enabled", True, True, True),
|
316 |
+
)
|
317 |
+
def test_end_to_end(self,
|
318 |
+
bi_data: bool,
|
319 |
+
use_eod_token: bool,
|
320 |
+
do_lower_case: bool):
|
321 |
+
tokenizer = _get_mock_tokenizer()
|
322 |
+
|
323 |
+
num_documents = 5
|
324 |
+
sentences_per_document = 10
|
325 |
+
document_length = 50
|
326 |
+
|
327 |
+
documents = [
|
328 |
+
["a " * document_length for _ in range(sentences_per_document)]
|
329 |
+
for _ in range(num_documents)]
|
330 |
+
|
331 |
+
save_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
332 |
+
files = _create_files(temp_dir=self.get_temp_dir(), file_contents=documents)
|
333 |
+
|
334 |
+
cpd.create_tfrecords(
|
335 |
+
tokenizer=tokenizer,
|
336 |
+
input_file_or_files=",".join(files),
|
337 |
+
use_eod_token=use_eod_token,
|
338 |
+
do_lower_case=do_lower_case,
|
339 |
+
per_host_batch_size=8,
|
340 |
+
seq_length=8,
|
341 |
+
reuse_length=4,
|
342 |
+
bi_data=bi_data,
|
343 |
+
num_cores_per_host=2,
|
344 |
+
save_dir=save_dir)
|
345 |
+
|
346 |
+
self.assertTrue(any(filter(lambda x: x.endswith(".json"),
|
347 |
+
os.listdir(save_dir))))
|
348 |
+
self.assertTrue(any(filter(lambda x: x.endswith(".tfrecord"),
|
349 |
+
os.listdir(save_dir))))
|
350 |
+
|
351 |
+
|
352 |
+
if __name__ == "__main__":
|
353 |
+
np.random.seed(0)
|
354 |
+
logging.set_verbosity(logging.INFO)
|
355 |
+
tf.test.main()
|
data_loader.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""An abstraction that NLP models define input pipelines."""
|
16 |
+
|
17 |
+
import abc
|
18 |
+
from typing import Optional
|
19 |
+
|
20 |
+
import tensorflow as tf, tf_keras
|
21 |
+
|
22 |
+
|
23 |
+
class DataLoader(metaclass=abc.ABCMeta):
|
24 |
+
"""An abstract class defining the APIs for tf.data input pipeline."""
|
25 |
+
|
26 |
+
@abc.abstractmethod
|
27 |
+
def load(
|
28 |
+
self,
|
29 |
+
input_context: Optional[tf.distribute.InputContext] = None
|
30 |
+
) -> tf.data.Dataset:
|
31 |
+
"""Implements DataLoader load method.
|
32 |
+
|
33 |
+
Builds the entire input pipeline inside the load method. Users can define
|
34 |
+
states inside the DataLoader class and return a tf.data dataset
|
35 |
+
object.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
input_context: This is a context class that is passed to the user's input
|
39 |
+
function and contains information about the compute replicas and input
|
40 |
+
pipelines. This object is used for multi-host inputs and passed by the
|
41 |
+
distribution strategy.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
A per-host tf.data dataset. Note that we usually create the distributed
|
45 |
+
dataset through the load method, so we should not directly return a
|
46 |
+
distributed dataset here.
|
47 |
+
"""
|
48 |
+
pass
|
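A minimal hypothetical subclass (not part of the module) showing the contract: construct the per-host pipeline inside `load` and shard it via the optional `input_context`.

class RangeDataLoader(DataLoader):
  """Toy loader that yields batched integer ranges."""

  def __init__(self, num_examples: int, batch_size: int):
    self._num_examples = num_examples
    self._batch_size = batch_size

  def load(self, input_context=None):
    dataset = tf.data.Dataset.range(self._num_examples)
    if input_context:
      # Let each input pipeline read a disjoint shard of the data.
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    return dataset.batch(self._batch_size)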
data_loader_factory.py
ADDED
@@ -0,0 +1,58 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A global factory to access NLP registered data loaders."""

from official.core import registry

_REGISTERED_DATA_LOADER_CLS = {}


def register_data_loader_cls(data_config_cls):
  """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.

  This decorator supports registration of data loaders as follows:

  ```
  @dataclasses.dataclass
  class MyDataConfig(DataConfig):
    # Add fields here.
    pass

  @register_data_loader_cls(MyDataConfig)
  class MyDataLoader:
    # Inherits def __init__(self, data_config).
    pass

  my_data_config = MyDataConfig()

  # Returns MyDataLoader(my_data_config).
  my_loader = get_data_loader(my_data_config)
  ```

  Args:
    data_config_cls: a subclass of DataConfig (*not* an instance
      of DataConfig).

  Returns:
    A callable for use as class decorator that registers the decorated class
    for creation from an instance of data_config_cls.
  """
  return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)


def get_data_loader(data_config):
  """Creates a data_loader from data_config."""
  return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
      data_config)
data_loader_factory_test.py
ADDED
@@ -0,0 +1,45 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.nlp.data.data_loader_factory."""

import dataclasses
import tensorflow as tf, tf_keras

from official.core import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class MyDataConfig(cfg.DataConfig):
  is_training: bool = True


@data_loader_factory.register_data_loader_cls(MyDataConfig)
class MyDataLoader:

  def __init__(self, params):
    self.params = params


class DataLoaderFactoryTest(tf.test.TestCase):

  def test_register_and_load(self):
    train_config = MyDataConfig()
    train_loader = data_loader_factory.get_data_loader(train_config)
    self.assertTrue(train_loader.params.is_training)


if __name__ == "__main__":
  tf.test.main()
dual_encoder_dataloader.py
ADDED
@@ -0,0 +1,147 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loads dataset for the dual encoder (retrieval) task."""
import dataclasses
import functools
import itertools
from typing import Iterable, Mapping, Optional, Tuple

import tensorflow as tf, tf_keras
import tensorflow_hub as hub

from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
from official.nlp.modeling import layers


@dataclasses.dataclass
class DualEncoderDataConfig(cfg.DataConfig):
  """Data config for dual encoder task (tasks/dual_encoder)."""
  # Either set `input_path`...
  input_path: str = ''
  # ...or `tfds_name` and `tfds_split` to specify input.
  tfds_name: str = ''
  tfds_split: str = ''
  global_batch_size: int = 32
  # Either build preprocessing with Python code by specifying these values...
  vocab_file: str = ''
  lower_case: bool = True
  # ...or load preprocessing from a SavedModel at this location.
  preprocessing_hub_module_url: str = ''

  left_text_fields: Tuple[str] = ('left_input',)
  right_text_fields: Tuple[str] = ('right_input',)
  is_training: bool = True
  seq_length: int = 128
  file_type: str = 'tfrecord'


@data_loader_factory.register_data_loader_cls(DualEncoderDataConfig)
class DualEncoderDataLoader(data_loader.DataLoader):
  """A class to load dataset for dual encoder task (tasks/dual_encoder)."""

  def __init__(self, params):
    if bool(params.tfds_name) == bool(params.input_path):
      raise ValueError('Must specify either `tfds_name` and `tfds_split` '
                       'or `input_path`.')
    if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
      raise ValueError('Must specify exactly one of vocab_file (with matching '
                       'lower_case flag) or preprocessing_hub_module_url.')
    self._params = params
    self._seq_length = params.seq_length
    self._left_text_fields = params.left_text_fields
    self._right_text_fields = params.right_text_fields

    if params.preprocessing_hub_module_url:
      preprocessing_hub_module = hub.load(params.preprocessing_hub_module_url)
      self._tokenizer = preprocessing_hub_module.tokenize
      self._pack_inputs = functools.partial(
          preprocessing_hub_module.bert_pack_inputs,
          seq_length=params.seq_length)
    else:
      self._tokenizer = layers.BertTokenizer(
          vocab_file=params.vocab_file, lower_case=params.lower_case)
      self._pack_inputs = layers.BertPackInputs(
          seq_length=params.seq_length,
          special_tokens_dict=self._tokenizer.get_special_tokens_dict())

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        x: tf.io.FixedLenFeature([], tf.string)
        for x in itertools.chain(
            *[self._left_text_fields, self._right_text_fields])
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _bert_tokenize(
      self, record: Mapping[str, tf.Tensor],
      text_fields: Iterable[str]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Tokenizes the input in text_fields using the BERT tokenizer.

    Args:
      record: A tf.Example record that contains the features.
      text_fields: A list of fields to be tokenized.

    Returns:
      The tokenized features in a tuple of (input_word_ids, input_mask,
      input_type_ids).
    """
    segments_text = [record[x] for x in text_fields]
    segments_tokens = [self._tokenizer(s) for s in segments_text]
    segments = [tf.cast(x.merge_dims(1, 2), tf.int32) for x in segments_tokens]
    return self._pack_inputs(segments)

  def _bert_preprocess(
      self, record: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Performs the BERT word piece tokenization for left and right inputs."""

    def _switch_prefix(string, old, new):
      if string.startswith(old): return new + string[len(old):]
      raise ValueError('Expected {} to start with {}'.format(string, old))

    def _switch_key_prefix(d, old, new):
      return {_switch_prefix(key, old, new): value for key, value in d.items()}  # pytype: disable=attribute-error  # trace-all-classes

    model_inputs = _switch_key_prefix(
        self._bert_tokenize(record, self._left_text_fields),
        'input_', 'left_')
    model_inputs.update(_switch_key_prefix(
        self._bert_tokenize(record, self._right_text_fields),
        'input_', 'right_'))
    return model_inputs

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        # Skip `decoder_fn` for tfds input.
        decoder_fn=self._decode if self._params.input_path else None,
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        postprocess_fn=self._bert_preprocess)
    return reader.read(input_context)
dual_encoder_dataloader_test.py
ADDED
@@ -0,0 +1,131 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.nlp.data.dual_encoder_dataloader."""
import os

from absl.testing import parameterized
import tensorflow as tf, tf_keras

from official.nlp.data import dual_encoder_dataloader


_LEFT_FEATURE_NAME = 'left_input'
_RIGHT_FEATURE_NAME = 'right_input'


def _create_fake_dataset(output_path):
  """Creates a fake dataset that contains examples for a dual encoder model.

  The created dataset contains examples with two byteslist features keyed by
  _LEFT_FEATURE_NAME and _RIGHT_FEATURE_NAME.

  Args:
    output_path: The output path of the fake dataset.
  """
  def create_str_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

  with tf.io.TFRecordWriter(output_path) as writer:
    for _ in range(100):
      features = {}
      features[_LEFT_FEATURE_NAME] = create_str_feature([b'hello world.'])
      features[_RIGHT_FEATURE_NAME] = create_str_feature([b'world hello.'])

      tf_example = tf.train.Example(
          features=tf.train.Features(feature=features))
      writer.write(tf_example.SerializeToString())


def _make_vocab_file(vocab, output_path):
  with tf.io.gfile.GFile(output_path, 'w') as f:
    f.write('\n'.join(vocab + ['']))


class DualEncoderDataTest(tf.test.TestCase, parameterized.TestCase):

  def test_load_dataset(self):
    seq_length = 16
    batch_size = 10
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')

    _create_fake_dataset(train_data_path)
    _make_vocab_file(
        ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'], vocab_path)

    data_config = dual_encoder_dataloader.DualEncoderDataConfig(
        input_path=train_data_path,
        seq_length=seq_length,
        vocab_file=vocab_path,
        lower_case=True,
        left_text_fields=(_LEFT_FEATURE_NAME,),
        right_text_fields=(_RIGHT_FEATURE_NAME,),
        global_batch_size=batch_size)
    dataset = dual_encoder_dataloader.DualEncoderDataLoader(
        data_config).load()
    features = next(iter(dataset))
    self.assertCountEqual(
        ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
         'right_mask', 'right_type_ids'],
        features.keys())
    self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))

  @parameterized.parameters(False, True)
  def test_load_tfds(self, use_preprocessing_hub):
    seq_length = 16
    batch_size = 10
    if use_preprocessing_hub:
      vocab_path = ''
      preprocessing_hub = (
          'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3')
    else:
      vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
      _make_vocab_file(
          ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
          vocab_path)
      preprocessing_hub = ''

    data_config = dual_encoder_dataloader.DualEncoderDataConfig(
        tfds_name='para_crawl/enmt',
        tfds_split='train',
        seq_length=seq_length,
        vocab_file=vocab_path,
        lower_case=True,
        left_text_fields=('en',),
        right_text_fields=('mt',),
        preprocessing_hub_module_url=preprocessing_hub,
        global_batch_size=batch_size)
    dataset = dual_encoder_dataloader.DualEncoderDataLoader(
        data_config).load()
    features = next(iter(dataset))
    self.assertCountEqual(
        ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
         'right_mask', 'right_type_ids'],
        features.keys())
    self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))


if __name__ == '__main__':
  tf.test.main()
pretrain_dataloader.py
ADDED
@@ -0,0 +1,589 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loads dataset for the BERT pretraining task."""
import dataclasses
from typing import Mapping, Optional

from absl import logging

import numpy as np
import tensorflow as tf, tf_keras
from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False
  # Historically, BERT implementations take `input_ids` and `segment_ids` as
  # feature names. Inside the TF Model Garden implementation, the Keras model
  # inputs are set as `input_word_ids` and `input_type_ids`. When
  # v2_feature_names is True, the data loader assumes the tf.Examples use
  # `input_word_ids` and `input_type_ids` as keys.
  use_v2_feature_names: bool = False
  file_type: str = 'tfrecord'


@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for bert pretraining task."""

  def __init__(self, params):
    """Inits `BertPretrainDataLoader` class.

    Args:
      params: A `BertPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._use_next_sentence_label = params.use_next_sentence_label
    self._use_position_id = params.use_position_id

  def _name_to_features(self):
    name_to_features = {
        'input_mask':
            tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'masked_lm_positions':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
        'masked_lm_ids':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
        'masked_lm_weights':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
    }
    if self._params.use_v2_feature_names:
      name_to_features.update({
          'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
          'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
      })
    else:
      name_to_features.update({
          'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
          'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
      })
    if self._use_next_sentence_label:
      name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
                                                                       tf.int64)
    if self._use_position_id:
      name_to_features['position_ids'] = tf.io.FixedLenFeature(
          [self._seq_length], tf.int64)
    return name_to_features

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = self._name_to_features()
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {
        'input_mask': record['input_mask'],
        'masked_lm_positions': record['masked_lm_positions'],
        'masked_lm_ids': record['masked_lm_ids'],
        'masked_lm_weights': record['masked_lm_weights'],
    }
    if self._params.use_v2_feature_names:
      x['input_word_ids'] = record['input_word_ids']
      x['input_type_ids'] = record['input_type_ids']
    else:
      x['input_word_ids'] = record['input_ids']
      x['input_type_ids'] = record['segment_ids']
    if self._use_next_sentence_label:
      x['next_sentence_labels'] = record['next_sentence_labels']
    if self._use_position_id:
      x['position_ids'] = record['position_ids']

    return x

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)


@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
  """Data config for XLNet pretraining task.

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: The length of each sequence.
    max_predictions_per_seq: The number of predictions per sequence.
    reuse_length: The number of tokens in a previous segment to reuse. This
      should be the same value used during pretrain data creation.
    sample_strategy: The strategy used to sample factorization permutations.
      Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
    min_num_tokens: The minimum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    max_num_tokens: The maximum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    min_num_words: The minimum number of words to sample in a span. This is used
      when `sample_strategy` is 'word_span'.
    max_num_words: The maximum number of words to sample in a span. This is used
      when `sample_strategy` is 'word_span'.
    permutation_size: The length of the longest permutation. This can be set to
      `reuse_length`. This should NOT be greater than `reuse_length`, otherwise
      this may introduce data leaks.
    leak_ratio: The percentage of masked tokens that are leaked.
    segment_sep_id: The ID of the SEP token used when preprocessing the dataset.
    segment_cls_id: The ID of the CLS token used when preprocessing the dataset.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  reuse_length: int = 256
  sample_strategy: str = 'word_span'
  min_num_tokens: int = 1
  max_num_tokens: int = 5
  min_num_words: int = 1
  max_num_words: int = 5
  permutation_size: int = 256
  leak_ratio: float = 0.1
  segment_sep_id: int = 4
  segment_cls_id: int = 3


@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for xlnet pretraining task."""

  def __init__(self, params: XLNetPretrainDataConfig):
    """Inits `XLNetPretrainDataLoader` class.

    Args:
      params: A `XLNetPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._reuse_length = params.reuse_length
    self._num_replicas_in_sync = None
    self._permutation_size = params.permutation_size
    self._sep_id = params.segment_sep_id
    self._cls_id = params.segment_cls_id
    self._sample_strategy = params.sample_strategy
    self._leak_ratio = params.leak_ratio

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'boundary_indices': tf.io.VarLenFeature(tf.int64),
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {}

    inputs = record['input_word_ids']
    x['input_type_ids'] = record['input_type_ids']

    if self._sample_strategy in ['whole_word', 'word_span']:
      boundary = tf.sparse.to_dense(record['boundary_indices'])
    else:
      boundary = None

    input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)

    if self._reuse_length > 0:
      if self._permutation_size > self._reuse_length:
        logging.warning(
            '`permutation_size` is greater than `reuse_length` (%d > %d). '
            'This may introduce data leakage.', self._permutation_size,
            self._reuse_length)

      # Enable the memory mechanism.
      # Permute the reuse and non-reuse segments separately.
      non_reuse_len = self._seq_length - self._reuse_length
      if not (self._reuse_length % self._permutation_size == 0 and
              non_reuse_len % self._permutation_size == 0):
        raise ValueError('`reuse_length` and `seq_length` should both be '
                         'a multiple of `permutation_size`.')

      # Creates permutation mask and target mask for the first reuse_len tokens.
      # The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
          inputs=inputs[:self._reuse_length],
          input_mask=input_mask[:self._reuse_length])

      # Creates permutation mask and target mask for the rest of the tokens in
      # the current example, which are a concatenation of two new segments.
      perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
          inputs[self._reuse_length:], input_mask[self._reuse_length:])

      perm_mask_0 = tf.concat([
          perm_mask_0,
          tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
      ], axis=1)
      perm_mask_1 = tf.concat([
          tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
          perm_mask_1
      ], axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      tokens = tf.concat([tokens_0, tokens_1], axis=0)
      masked_tokens = tf.concat([masked_0, masked_1], axis=0)
    else:
      # Disable the memory mechanism.
      if self._seq_length % self._permutation_size != 0:
        raise ValueError('`seq_length` should be a multiple of '
                         '`permutation_size`.')
      # Permute the entire sequence together
      perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
          inputs=inputs, input_mask=input_mask)
    x['permutation_mask'] = tf.reshape(perm_mask,
                                       [self._seq_length, self._seq_length])
    x['input_word_ids'] = tokens
    x['masked_tokens'] = masked_tokens

    target = tokens
    if self._max_predictions_per_seq is not None:
      indices = tf.range(self._seq_length, dtype=tf.int32)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      # account for extra padding due to CLS/SEP.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = self._max_predictions_per_seq - actual_num_predict

      target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
      paddings = tf.zeros([pad_len, self._seq_length],
                          dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      x['target_mapping'] = tf.reshape(
          target_mapping, [self._max_predictions_per_seq, self._seq_length])

      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      x['target'] = tf.reshape(target, [self._max_predictions_per_seq])

      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.int32),
          tf.zeros([pad_len], dtype=tf.int32)
      ], axis=0)
      x['target_mask'] = tf.reshape(target_mask,
                                    [self._max_predictions_per_seq])
    else:
      x['target'] = tf.reshape(target, [self._seq_length])
      x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
    return x

  def _index_pair_to_mask(self, begin_indices: tf.Tensor,
                          end_indices: tf.Tensor,
                          inputs: tf.Tensor) -> tf.Tensor:
    """Converts beginning and end indices into an actual mask."""
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    all_indices = tf.where(
        non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
        tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
    candidate_matrix = tf.cast(
        tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
                       all_indices[None, :] < end_indices[:, None]), tf.float32)
    cumsum_matrix = tf.reshape(
        tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
    masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
                            tf.float32)
    target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
    return tf.cast(target_mask, tf.bool)

  def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples individual tokens as prediction targets."""
    all_indices = tf.range(self._seq_length, dtype=tf.int32)
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    non_func_indices = tf.boolean_mask(all_indices, non_func_mask)

    masked_pos = tf.random.shuffle(non_func_indices)
    masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])

    sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
    sparse_indices = tf.cast(sparse_indices, tf.int64)

    sparse_indices = tf.sparse.SparseTensor(
        sparse_indices,
        values=tf.ones_like(masked_pos),
        dense_shape=(1, self._seq_length))

    target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)

    return tf.squeeze(tf.cast(target_mask, tf.bool))

  def _whole_word_mask(self, inputs: tf.Tensor,
                       boundary: tf.Tensor) -> tf.Tensor:
    """Samples whole words as prediction targets."""
    pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
    cand_pair_indices = tf.random.shuffle(
        pair_indices)[:self._max_predictions_per_seq]
    begin_indices = cand_pair_indices[:, 0]
    end_indices = cand_pair_indices[:, 1]

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples token spans as prediction targets."""
    min_num_tokens = self._params.min_num_tokens
    max_num_tokens = self._params.max_num_tokens

    mask_alpha = self._seq_length / self._max_predictions_per_seq
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Sample span lengths from a zipf distribution
    span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])

    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_num_tokens

    # Sample the ratio [0.0, 1.0) of left context lengths
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
    left_ctx_len = round_to_int(left_ctx_len)

    # Compute the offset from left start to the right end
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    # Get the actual begin and end indices
    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens

    # Remove out of range indices
    valid_idx_mask = end_indices < self._seq_length
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Shuffle valid indices
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    begin_indices = tf.gather(begin_indices, order)
    end_indices = tf.gather(end_indices, order)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
    """Samples whole word spans as prediction targets."""
    min_num_words = self._params.min_num_words
    max_num_words = self._params.max_num_words

    # Note: 1.2 is the token-to-word ratio
    mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Sample span lengths from a zipf distribution
    span_len_seq = np.arange(min_num_words, max_num_words + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)

    # Sample `num_predict` words here: note that this is over sampling
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_num_words

    # Sample the ratio [0.0, 1.0) of left context lengths
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)

    left_ctx_len = round_to_int(left_ctx_len)
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens

    # Remove out of range indices
    max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
    valid_idx_mask = end_indices < max_boundary_index
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    begin_indices = tf.gather(boundary, begin_indices)
    end_indices = tf.gather(boundary, end_indices)

    # Shuffle valid indices
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    begin_indices = tf.gather(begin_indices, order)
    end_indices = tf.gather(end_indices, order)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _online_sample_mask(self, inputs: tf.Tensor,
                          boundary: tf.Tensor) -> tf.Tensor:
    """Samples target positions for predictions.

    Descriptions of each strategy:
      - 'single_token': Samples individual tokens as prediction targets.
      - 'token_span': Samples spans of tokens as prediction targets.
      - 'whole_word': Samples individual words as prediction targets.
      - 'word_span': Samples spans of words as prediction targets.

    Args:
      inputs: The input tokens.
      boundary: The `int` Tensor of indices indicating whole word boundaries.
        This is used in 'whole_word' and 'word_span'.

    Returns:
      The sampled `bool` input mask.

    Raises:
      `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
        not provided for 'whole_word' and 'word_span' sample strategies.
    """
    if self._max_predictions_per_seq is None:
      raise ValueError('`max_predictions_per_seq` must be set.')

    if boundary is None and 'word' in self._sample_strategy:
      raise ValueError('`boundary` must be provided for {} strategy'.format(
          self._sample_strategy))

    if self._sample_strategy == 'single_token':
      return self._single_token_mask(inputs)
    elif self._sample_strategy == 'token_span':
      return self._token_span_mask(inputs)
    elif self._sample_strategy == 'whole_word':
      return self._whole_word_mask(inputs, boundary)
    elif self._sample_strategy == 'word_span':
      return self._word_span_mask(inputs, boundary)
    else:
      raise NotImplementedError('Invalid sample strategy.')

  def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
    """Samples a permutation of the factorization order.

    Args:
      inputs: the input tokens.
      input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
        then the token is selected for partial prediction.

    Returns:
      perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
        of 0s and 1s. If perm_mask[i][j] == 0, then this means that the i-th
        token (in original order) cannot attend to the j-th token.
      target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and 1s.
        If target_mask[i] == 1, then the i-th token needs to be predicted and
        the mask will be used as input. This token will be included in the loss.
        If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
        input. This token will not be included in the loss.
      tokens: int32 Tensor of shape [seq_length].
      masked_tokens: int32 Tensor of shape [seq_length].
    """
    factorization_length = tf.shape(inputs)[0]
    # Generate permutation indices
    index = tf.range(factorization_length, dtype=tf.int32)
    index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
    index = tf.random.shuffle(index)
    index = tf.reshape(tf.transpose(index), [-1])

    input_mask = tf.cast(input_mask, tf.bool)

    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(
            tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
    masked_tokens = tf.logical_and(input_mask, non_func_tokens)
    non_masked_or_func_tokens = tf.logical_not(masked_tokens)

    smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)

    # Similar to BERT, randomly leak some masked tokens
    if self._leak_ratio > 0:
      leak_tokens = tf.logical_and(
          masked_tokens,
          tf.random.uniform([factorization_length], maxval=1.0) <
          self._leak_ratio)
      can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
    else:
      can_attend_self = non_masked_or_func_tokens
    to_index = tf.where(can_attend_self, smallest_index, index)
    from_index = tf.where(can_attend_self, to_index + 1, to_index)

    # For masked tokens, can attend if i > j
    # For context tokens, always can attend each other
    can_attend = from_index[:, None] > to_index[None, :]

    perm_mask = tf.cast(can_attend, tf.int32)

    # Only masked tokens are included in the loss
    target_mask = tf.cast(masked_tokens, tf.int32)

    return perm_mask, target_mask, inputs, masked_tokens

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    if input_context:
      self._num_replicas_in_sync = input_context.num_replicas_in_sync
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
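The `to_index`/`from_index` construction in `_get_factorization` above is the densest part of this loader, so here is a standalone NumPy sketch written for this note (it is not one of the uploaded files, and it ignores the `leak_ratio` branch). It reproduces only the comparison `from_index[:, None] > to_index[None, :]` on a toy sequence to show which positions end up able to attend to which: masked positions may only attend to positions earlier in the sampled factorization order, while unmasked context positions attend to all context but never to masked content.

```
import numpy as np

rng = np.random.default_rng(0)

seq_len = 6
# Positions sampled as prediction targets by one of the masking strategies.
masked = np.array([False, True, False, True, True, False])
factorization_order = rng.permutation(seq_len).astype(np.int32)

# Mirrors the loader: context tokens get the sentinel -2, masked tokens get
# their position in the sampled factorization order.
smallest_index = np.full(seq_len, -2, dtype=np.int32)
to_index = np.where(~masked, smallest_index, factorization_order)
from_index = np.where(~masked, to_index + 1, to_index)

# perm_mask[i, j] == 1 means token i may attend to token j.
perm_mask = (from_index[:, None] > to_index[None, :]).astype(np.int32)
print(perm_mask)
# Context rows see every context column (-1 > -2) but no masked column;
# a masked row sees context plus only those masked columns that come
# earlier than it in `factorization_order`.
```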
pretrain_dataloader_test.py
ADDED
@@ -0,0 +1,242 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.nlp.data.pretrain_dataloader."""
import itertools
import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras

from official.nlp.data import pretrain_dataloader


def create_int_feature(values):
  f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
  return f


def _create_fake_bert_dataset(
    output_path,
    seq_length,
    max_predictions_per_seq,
    use_position_id,
    use_next_sentence_label,
    use_v2_feature_names=False):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_float_feature(values):
    f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
    return f

  for _ in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    features["input_mask"] = create_int_feature(np.ones_like(input_ids))
    if use_v2_feature_names:
      features["input_word_ids"] = create_int_feature(input_ids)
      features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
    else:
      features["input_ids"] = create_int_feature(input_ids)
      features["segment_ids"] = create_int_feature(np.ones_like(input_ids))

    features["masked_lm_positions"] = create_int_feature(
        np.random.randint(100, size=(max_predictions_per_seq)))
    features["masked_lm_ids"] = create_int_feature(
        np.random.randint(100, size=(max_predictions_per_seq)))
    features["masked_lm_weights"] = create_float_feature(
        [1.0] * max_predictions_per_seq)

    if use_next_sentence_label:
      features["next_sentence_labels"] = create_int_feature([1])

    if use_position_id:
      features["position_ids"] = create_int_feature(range(0, seq_length))

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


def _create_fake_xlnet_dataset(
    output_path, seq_length, max_predictions_per_seq):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)
  for _ in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    num_boundary_indices = np.random.randint(1, seq_length)

    if max_predictions_per_seq is not None:
      input_mask = np.zeros_like(input_ids)
      input_mask[:max_predictions_per_seq] = 1
      np.random.shuffle(input_mask)
    else:
      input_mask = np.ones_like(input_ids)

    features["input_mask"] = create_int_feature(input_mask)
    features["input_word_ids"] = create_int_feature(input_ids)
    features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
    features["boundary_indices"] = create_int_feature(
        sorted(np.random.randint(seq_length, size=(num_boundary_indices))))
    features["target"] = create_int_feature(input_ids + 1)
    features["label"] = create_int_feature([1])
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(itertools.product(
      (False, True),
      (False, True),
  ))
  def test_load_data(self, use_next_sentence_label, use_position_id):
    train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
    seq_length = 128
    max_predictions_per_seq = 20
    _create_fake_bert_dataset(
        train_data_path,
        seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label,
        use_position_id=use_position_id)
    data_config = pretrain_dataloader.BertPretrainDataConfig(
        input_path=train_data_path,
        max_predictions_per_seq=max_predictions_per_seq,
        seq_length=seq_length,
        global_batch_size=10,
        is_training=True,
        use_next_sentence_label=use_next_sentence_label,
        use_position_id=use_position_id)

    dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
    features = next(iter(dataset))
    self.assertLen(features,
                   6 + int(use_next_sentence_label) + int(use_position_id))
    self.assertIn("input_word_ids", features)
    self.assertIn("input_mask", features)
    self.assertIn("input_type_ids", features)
    self.assertIn("masked_lm_positions", features)
    self.assertIn("masked_lm_ids", features)
    self.assertIn("masked_lm_weights", features)

    self.assertEqual("next_sentence_labels" in features,
                     use_next_sentence_label)
    self.assertEqual("position_ids" in features, use_position_id)

  def test_v2_feature_names(self):
    train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
    seq_length = 128
    max_predictions_per_seq = 20
    _create_fake_bert_dataset(
        train_data_path,
        seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=True,
        use_position_id=False,
        use_v2_feature_names=True)
    data_config = pretrain_dataloader.BertPretrainDataConfig(
        input_path=train_data_path,
        max_predictions_per_seq=max_predictions_per_seq,
        seq_length=seq_length,
        global_batch_size=10,
        is_training=True,
        use_next_sentence_label=True,
        use_position_id=False,
        use_v2_feature_names=True)

    dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
    features = next(iter(dataset))
    self.assertIn("input_word_ids", features)
    self.assertIn("input_mask", features)
    self.assertIn("input_type_ids", features)
    self.assertIn("masked_lm_positions", features)
    self.assertIn("masked_lm_ids", features)
    self.assertIn("masked_lm_weights", features)


class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(itertools.product(
      ("single_token", "whole_word", "token_span"),
      (0, 64),
      (20, None),
  ))
  def test_load_data(
      self, sample_strategy, reuse_length, max_predictions_per_seq):
    train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
    seq_length = 128
    batch_size = 5

    _create_fake_xlnet_dataset(
        train_data_path, seq_length, max_predictions_per_seq)

    data_config = pretrain_dataloader.XLNetPretrainDataConfig(
        input_path=train_data_path,
        max_predictions_per_seq=max_predictions_per_seq,
        seq_length=seq_length,
        global_batch_size=batch_size,
        is_training=True,
        reuse_length=reuse_length,
        sample_strategy=sample_strategy,
        min_num_tokens=1,
        max_num_tokens=2,
        permutation_size=seq_length // 2,
        leak_ratio=0.1)

    if max_predictions_per_seq is None:
      with self.assertRaises(ValueError):
        dataset = pretrain_dataloader.XLNetPretrainDataLoader(
            data_config).load()
        features = next(iter(dataset))
    else:
      dataset = pretrain_dataloader.XLNetPretrainDataLoader(data_config).load()
      features = next(iter(dataset))

      self.assertIn("input_word_ids", features)
      self.assertIn("input_type_ids", features)
      self.assertIn("permutation_mask", features)
      self.assertIn("masked_tokens", features)
      self.assertIn("target", features)
      self.assertIn("target_mask", features)

      self.assertAllClose(features["input_word_ids"].shape,
                          (batch_size, seq_length))
      self.assertAllClose(features["input_type_ids"].shape,
                          (batch_size, seq_length))
      self.assertAllClose(features["permutation_mask"].shape,
                          (batch_size, seq_length, seq_length))
      self.assertAllClose(features["masked_tokens"].shape,
                          (batch_size, seq_length,))
      if max_predictions_per_seq is not None:
        self.assertIn("target_mapping", features)
        self.assertAllClose(features["target_mapping"].shape,
                            (batch_size, max_predictions_per_seq, seq_length))
        self.assertAllClose(features["target_mask"].shape,
                            (batch_size, max_predictions_per_seq))
        self.assertAllClose(features["target"].shape,
                            (batch_size, max_predictions_per_seq))
      else:
        self.assertAllClose(features["target_mask"].shape,
                            (batch_size, seq_length))
        self.assertAllClose(features["target"].shape,
                            (batch_size, seq_length))


if __name__ == "__main__":
  tf.test.main()
pretrain_dynamic_dataloader.py
ADDED
@@ -0,0 +1,223 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Dataset loader for the pre-training with dynamic sequence length."""
|
16 |
+
from typing import Optional, Tuple
|
17 |
+
|
18 |
+
import dataclasses
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
|
21 |
+
from official.core import config_definitions as cfg
|
22 |
+
from official.core import input_reader
|
23 |
+
from official.nlp.data import data_loader_factory
|
24 |
+
from official.nlp.data import pretrain_dataloader
|
25 |
+
|
26 |
+
|
27 |
+
@dataclasses.dataclass
|
28 |
+
class BertPretrainDataConfig(cfg.DataConfig):
|
29 |
+
"""Data config for BERT pretraining task (tasks/masked_lm)."""
|
30 |
+
input_path: str = ''
|
31 |
+
global_batch_size: int = 512
|
32 |
+
is_training: bool = True
|
33 |
+
seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512,)
|
34 |
+
# TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
|
35 |
+
# tf.data service is disabled. Deprecate this flag once we always enable round
|
36 |
+
# robin tf.data service.
|
37 |
+
seq_bucket_window_scale: int = 8
|
38 |
+
use_next_sentence_label: bool = True
|
39 |
+
use_position_id: bool = False
|
40 |
+
deterministic: bool = False
|
41 |
+
enable_tf_data_service: bool = False
|
42 |
+
enable_round_robin_tf_data_service: bool = False
|
43 |
+
tf_data_service_job_name: str = 'bert_pretrain'
|
44 |
+
use_v2_feature_names: bool = False
|
45 |
+
|
46 |
+
|
47 |
+
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
|
48 |
+
class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
|
49 |
+
"""Dataset loader for bert-style pretraining with dynamic sequenece length.
|
50 |
+
|
51 |
+
Bucketizes the input id features by the seq_bucket_lengths and features are
|
52 |
+
padded to the bucket boundaries. The mask features are usually shorter than
|
53 |
+
input id features and can also be dynamic. We require the mask feature lengths
|
54 |
+
within a bucket to be the same. For example, with [128, 256] buckets,
|
55 |
+
the mask features for bucket 128 should always have the same length X and
|
56 |
+
features for bucket 256 should always have the same length Y.
|
57 |
+
|
58 |
+
The dataloader does not filter out empty masks. Make sure to handle this
|
59 |
+
in the model.
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, params):
|
63 |
+
self._params = params
|
64 |
+
if len(params.seq_bucket_lengths) < 1:
|
65 |
+
raise ValueError('The seq_bucket_lengths cannot be empty.')
|
66 |
+
self._seq_bucket_lengths = params.seq_bucket_lengths
|
67 |
+
self._seq_bucket_window_scale = params.seq_bucket_window_scale
|
68 |
+
self._global_batch_size = params.global_batch_size
|
69 |
+
self._use_next_sentence_label = params.use_next_sentence_label
|
70 |
+
self._use_position_id = params.use_position_id
|
71 |
+
self._drop_remainder = params.drop_remainder
|
72 |
+
self._enable_tf_data_service = params.enable_tf_data_service
|
73 |
+
self._enable_round_robin_tf_data_service = (
|
74 |
+
params.enable_round_robin_tf_data_service)
|
75 |
+
self._mask_keys = [
|
76 |
+
'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
|
77 |
+
]
|
78 |
+
|
79 |
+
def _decode(self, record: tf.Tensor):
|
80 |
+
"""Decodes a serialized tf.Example."""
|
81 |
+
name_to_features = {
|
82 |
+
'input_mask': tf.io.VarLenFeature(tf.int64),
|
83 |
+
'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
|
84 |
+
'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
|
85 |
+
'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
|
86 |
+
}
|
87 |
+
if self._params.use_v2_feature_names:
|
88 |
+
input_ids_key = 'input_word_ids'
|
89 |
+
segment_key = 'input_type_ids'
|
90 |
+
name_to_features.update({
|
91 |
+
input_ids_key: tf.io.VarLenFeature(tf.int64),
|
92 |
+
segment_key: tf.io.VarLenFeature(tf.int64),
|
93 |
+
})
|
94 |
+
else:
|
95 |
+
input_ids_key = 'input_ids'
|
96 |
+
segment_key = 'segment_ids'
|
97 |
+
name_to_features.update({
|
98 |
+
input_ids_key: tf.io.VarLenFeature(tf.int64),
|
99 |
+
segment_key: tf.io.VarLenFeature(tf.int64),
|
100 |
+
})
|
101 |
+
if self._use_next_sentence_label:
|
102 |
+
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
|
103 |
+
tf.int64)
|
104 |
+
dynamic_keys = [input_ids_key, 'input_mask', segment_key]
|
105 |
+
if self._use_position_id:
|
106 |
+
name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
|
107 |
+
dynamic_keys.append('position_ids')
|
108 |
+
|
109 |
+
example = tf.io.parse_single_example(record, name_to_features)
|
110 |
+
for key in dynamic_keys + self._mask_keys:
|
111 |
+
example[key] = tf.sparse.to_dense(example[key])
|
112 |
+
|
113 |
+
# Truncate padding that comes after the last non-pad token in the
|
114 |
+
# sequence length dimension.
|
115 |
+
# Padding that appears before the last non-pad token must not be removed.
|
116 |
+
mask = tf.math.greater(
|
117 |
+
tf.math.cumsum(example[input_ids_key], reverse=True), 0)
|
118 |
+
for key in dynamic_keys:
|
119 |
+
example[key] = tf.boolean_mask(example[key], mask)
|
120 |
+
|
121 |
+
# masked_lm_ids should be 0 padded.
|
122 |
+
# Change mask features to -1 padding so that we can differentiate
|
123 |
+
# padding from data or from bucketizing.
|
124 |
+
mask = tf.math.not_equal(example['masked_lm_ids'], 0)
|
125 |
+
example['masked_lm_ids'] = tf.where(
|
126 |
+
mask, example['masked_lm_ids'],
|
127 |
+
-tf.ones(tf.shape(example['masked_lm_ids']), dtype=example['masked_lm_ids'].dtype))
|
128 |
+
|
129 |
+
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
130 |
+
# So cast all int64 to int32.
|
131 |
+
# tf.data service uses dataset graph fingerprint to distinguish input
|
132 |
+
# pipeline jobs, thus we sort the keys here to make sure they are generated
|
133 |
+
# in a deterministic order each time the dataset function is traced.
|
134 |
+
for name in sorted(list(example.keys())):
|
135 |
+
t = example[name]
|
136 |
+
if t.dtype == tf.int64:
|
137 |
+
t = tf.cast(t, tf.int32)
|
138 |
+
example[name] = t
|
139 |
+
|
140 |
+
return example
|
141 |
+
|
142 |
+
def _bucketize_and_batch(
|
143 |
+
self,
|
144 |
+
dataset,
|
145 |
+
input_context: Optional[tf.distribute.InputContext] = None):
|
146 |
+
"""Bucketize by sequence length and batch the datasets."""
|
147 |
+
per_replica_batch_size = input_context.get_per_replica_batch_size(
|
148 |
+
self._global_batch_size) if input_context else self._global_batch_size
|
149 |
+
|
150 |
+
def element_length_func(example, seq_len_dim):
|
151 |
+
return tf.shape(example['input_word_ids'])[seq_len_dim]
|
152 |
+
|
153 |
+
bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
|
154 |
+
bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)
|
155 |
+
|
156 |
+
# Bucketize and batch the dataset with per replica batch size first.
|
157 |
+
dataset = dataset.apply(
|
158 |
+
tf.data.experimental.bucket_by_sequence_length(
|
159 |
+
lambda example: tf.cast(element_length_func(example, 0), tf.int32),
|
160 |
+
bucket_boundaries,
|
161 |
+
bucket_batch_sizes,
|
162 |
+
pad_to_bucket_boundary=True,
|
163 |
+
drop_remainder=self._drop_remainder))
|
164 |
+
if input_context:
|
165 |
+
window_size = input_context.num_replicas_in_sync
|
166 |
+
if self._enable_tf_data_service and (
|
167 |
+
not self._enable_round_robin_tf_data_service):
|
168 |
+
# If tf.data service is enabled but round-robin behavior is not enabled,
|
169 |
+
# different TPU workers may fetch data from one tf.data service worker
|
170 |
+
# in different speed. We set the window size to be
|
171 |
+
# `seq_bucket_window_scale` larger to leave buffer if some workers are
|
172 |
+
# fetching data faster than others, so all the data within the same
|
173 |
+
# global batch can still have more chances to be in the same bucket.
|
174 |
+
window_size *= self._seq_bucket_window_scale
|
175 |
+
|
176 |
+
# Group `num_replicas_in_sync` batches from same bucket together, so all
|
177 |
+
# replicas can get the same sequence length for one global step.
|
178 |
+
dataset = dataset.apply(
|
179 |
+
tf.data.experimental.group_by_window(
|
180 |
+
key_func=lambda example: tf.cast( # pylint: disable=g-long-lambda
|
181 |
+
element_length_func(example, 1), tf.int64),
|
182 |
+
reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
|
183 |
+
window_size=window_size))
|
184 |
+
dataset = dataset.flat_map(lambda x: x)
|
185 |
+
|
186 |
+
def _remove_pads_from_bucketize(features):
|
187 |
+
# All mask features must have the same effective length.
|
188 |
+
# The real masked ids padding token is -1 and 0 comes from
|
189 |
+
# bucket_by_sequence_length.
|
190 |
+
mask = tf.math.not_equal(features['masked_lm_ids'], 0)
|
191 |
+
|
192 |
+
mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
|
193 |
+
normalized = tf.cast(
|
194 |
+
mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
|
195 |
+
assert_op = tf.debugging.assert_equal(
|
196 |
+
tf.math.reduce_sum(normalized), per_replica_batch_size,
|
197 |
+
'Number of non padded mask tokens is not the same for each example '
|
198 |
+
'in the same sequence length.')
|
199 |
+
with tf.control_dependencies([assert_op]):
|
200 |
+
for key in self._mask_keys:
|
201 |
+
features[key] = tf.reshape(
|
202 |
+
tf.boolean_mask(
|
203 |
+
features[key], mask), [per_replica_batch_size, -1])
|
204 |
+
# Revert masked_lm_ids to be 0-padded.
|
205 |
+
mask = tf.math.not_equal(features['masked_lm_ids'], -1)
|
206 |
+
features['masked_lm_ids'] = tf.where(
|
207 |
+
mask, features['masked_lm_ids'],
|
208 |
+
tf.zeros(
|
209 |
+
tf.shape(features['masked_lm_ids']),
|
210 |
+
dtype=features['masked_lm_ids'].dtype))
|
211 |
+
return features
|
212 |
+
|
213 |
+
dataset = dataset.map(_remove_pads_from_bucketize)
|
214 |
+
return dataset
|
215 |
+
|
216 |
+
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
217 |
+
"""Returns a tf.dataset.Dataset."""
|
218 |
+
reader = input_reader.InputReader(
|
219 |
+
params=self._params,
|
220 |
+
decoder_fn=self._decode,
|
221 |
+
parser_fn=self._parse,
|
222 |
+
transform_and_batch_fn=self._bucketize_and_batch)
|
223 |
+
return reader.read(input_context)
|
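A minimal usage sketch for the dynamic-length loader above. The TFRecord path and the chosen bucket lengths are illustrative assumptions, not values required by the loader; it only expects a pre-processed BERT pretraining TFRecord.

from official.nlp.data import pretrain_dynamic_dataloader

data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
    input_path='/tmp/pretrain.tf_record',  # hypothetical pre-processed TFRecord
    is_training=True,
    global_batch_size=32,
    seq_bucket_lengths=(128, 256, 384, 512))
dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
    data_config).load()
# Each batch is padded only up to its bucket boundary, so the sequence
# dimension can differ from batch to batch (128, 256, 384 or 512 here).
features = next(iter(dataset))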
pretrain_dynamic_dataloader_test.py
ADDED
@@ -0,0 +1,245 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for nlp.data.pretrain_dynamic_dataloader."""
|
16 |
+
import os
|
17 |
+
|
18 |
+
from absl import logging
|
19 |
+
from absl.testing import parameterized
|
20 |
+
import numpy as np
|
21 |
+
import orbit
|
22 |
+
import tensorflow as tf, tf_keras
|
23 |
+
|
24 |
+
from tensorflow.python.distribute import combinations
|
25 |
+
from tensorflow.python.distribute import strategy_combinations
|
26 |
+
from official.nlp.configs import bert
|
27 |
+
from official.nlp.configs import encoders
|
28 |
+
from official.nlp.data import pretrain_dataloader
|
29 |
+
from official.nlp.data import pretrain_dynamic_dataloader
|
30 |
+
from official.nlp.tasks import masked_lm
|
31 |
+
|
32 |
+
|
33 |
+
def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
|
34 |
+
max_seq_length, num_examples):
|
35 |
+
"""Creates a fake dataset."""
|
36 |
+
writer = tf.io.TFRecordWriter(output_path)
|
37 |
+
|
38 |
+
def create_int_feature(values):
|
39 |
+
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
40 |
+
return f
|
41 |
+
|
42 |
+
def create_float_feature(values):
|
43 |
+
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
44 |
+
return f
|
45 |
+
|
46 |
+
rng = np.random.default_rng(37)
|
47 |
+
for _ in range(num_examples):
|
48 |
+
features = {}
|
49 |
+
padding = np.zeros(shape=(max_seq_length - seq_length), dtype=np.int32)
|
50 |
+
input_ids = rng.integers(low=1, high=100, size=(seq_length))
|
51 |
+
features['input_ids'] = create_int_feature(
|
52 |
+
np.concatenate((input_ids, padding)))
|
53 |
+
features['input_mask'] = create_int_feature(
|
54 |
+
np.concatenate((np.ones_like(input_ids), padding)))
|
55 |
+
features['segment_ids'] = create_int_feature(
|
56 |
+
np.concatenate((np.ones_like(input_ids), padding)))
|
57 |
+
features['position_ids'] = create_int_feature(
|
58 |
+
np.concatenate((np.ones_like(input_ids), padding)))
|
59 |
+
features['masked_lm_positions'] = create_int_feature(
|
60 |
+
rng.integers(60, size=(num_masked_tokens), dtype=np.int64))
|
61 |
+
features['masked_lm_ids'] = create_int_feature(
|
62 |
+
rng.integers(100, size=(num_masked_tokens), dtype=np.int64))
|
63 |
+
features['masked_lm_weights'] = create_float_feature(
|
64 |
+
np.ones((num_masked_tokens,), dtype=np.float32))
|
65 |
+
features['next_sentence_labels'] = create_int_feature(np.array([0]))
|
66 |
+
|
67 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
68 |
+
writer.write(tf_example.SerializeToString())
|
69 |
+
writer.close()
|
70 |
+
|
71 |
+
|
72 |
+
class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
|
73 |
+
|
74 |
+
@combinations.generate(
|
75 |
+
combinations.combine(
|
76 |
+
distribution_strategy=[
|
77 |
+
strategy_combinations.cloud_tpu_strategy,
|
78 |
+
],
|
79 |
+
mode='eager'))
|
80 |
+
def test_distribution_strategy(self, distribution_strategy):
|
81 |
+
max_seq_length = 128
|
82 |
+
batch_size = 8
|
83 |
+
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
84 |
+
_create_fake_dataset(
|
85 |
+
input_path,
|
86 |
+
seq_length=60,
|
87 |
+
num_masked_tokens=20,
|
88 |
+
max_seq_length=max_seq_length,
|
89 |
+
num_examples=batch_size)
|
90 |
+
data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
|
91 |
+
is_training=False,
|
92 |
+
input_path=input_path,
|
93 |
+
seq_bucket_lengths=[64, 128],
|
94 |
+
global_batch_size=batch_size)
|
95 |
+
dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
|
96 |
+
data_config)
|
97 |
+
distributed_ds = orbit.utils.make_distributed_dataset(
|
98 |
+
distribution_strategy, dataloader.load)
|
99 |
+
train_iter = iter(distributed_ds)
|
100 |
+
with distribution_strategy.scope():
|
101 |
+
config = masked_lm.MaskedLMConfig(
|
102 |
+
init_checkpoint=self.get_temp_dir(),
|
103 |
+
model=bert.PretrainerConfig(
|
104 |
+
encoders.EncoderConfig(
|
105 |
+
bert=encoders.BertEncoderConfig(
|
106 |
+
vocab_size=30522, num_layers=1)),
|
107 |
+
cls_heads=[
|
108 |
+
bert.ClsHeadConfig(
|
109 |
+
inner_dim=10, num_classes=2, name='next_sentence')
|
110 |
+
]),
|
111 |
+
train_data=data_config)
|
112 |
+
task = masked_lm.MaskedLMTask(config)
|
113 |
+
model = task.build_model()
|
114 |
+
metrics = task.build_metrics()
|
115 |
+
|
116 |
+
@tf.function
|
117 |
+
def step_fn(features):
|
118 |
+
return task.validation_step(features, model, metrics=metrics)
|
119 |
+
|
120 |
+
distributed_outputs = distribution_strategy.run(
|
121 |
+
step_fn, args=(next(train_iter),))
|
122 |
+
local_results = tf.nest.map_structure(
|
123 |
+
distribution_strategy.experimental_local_results, distributed_outputs)
|
124 |
+
logging.info('Dynamic padding: local_results= %s', str(local_results))
|
125 |
+
dynamic_metrics = {}
|
126 |
+
for metric in metrics:
|
127 |
+
dynamic_metrics[metric.name] = metric.result()
|
128 |
+
|
129 |
+
data_config = pretrain_dataloader.BertPretrainDataConfig(
|
130 |
+
is_training=False,
|
131 |
+
input_path=input_path,
|
132 |
+
seq_length=max_seq_length,
|
133 |
+
max_predictions_per_seq=20,
|
134 |
+
global_batch_size=batch_size)
|
135 |
+
dataloader = pretrain_dataloader.BertPretrainDataLoader(data_config)
|
136 |
+
distributed_ds = orbit.utils.make_distributed_dataset(
|
137 |
+
distribution_strategy, dataloader.load)
|
138 |
+
train_iter = iter(distributed_ds)
|
139 |
+
with distribution_strategy.scope():
|
140 |
+
metrics = task.build_metrics()
|
141 |
+
|
142 |
+
@tf.function
|
143 |
+
def step_fn_b(features):
|
144 |
+
return task.validation_step(features, model, metrics=metrics)
|
145 |
+
|
146 |
+
distributed_outputs = distribution_strategy.run(
|
147 |
+
step_fn_b, args=(next(train_iter),))
|
148 |
+
local_results = tf.nest.map_structure(
|
149 |
+
distribution_strategy.experimental_local_results, distributed_outputs)
|
150 |
+
logging.info('Static padding: local_results= %s', str(local_results))
|
151 |
+
static_metrics = {}
|
152 |
+
for metric in metrics:
|
153 |
+
static_metrics[metric.name] = metric.result()
|
154 |
+
for key in static_metrics:
|
155 |
+
# We need to investigate the differences in losses.
|
156 |
+
if key != 'next_sentence_loss':
|
157 |
+
self.assertEqual(dynamic_metrics[key], static_metrics[key])
|
158 |
+
|
159 |
+
def test_load_dataset(self):
|
160 |
+
tf.random.set_seed(0)
|
161 |
+
max_seq_length = 128
|
162 |
+
batch_size = 2
|
163 |
+
input_path_1 = os.path.join(self.get_temp_dir(), 'train_1.tf_record')
|
164 |
+
_create_fake_dataset(
|
165 |
+
input_path_1,
|
166 |
+
seq_length=60,
|
167 |
+
num_masked_tokens=20,
|
168 |
+
max_seq_length=max_seq_length,
|
169 |
+
num_examples=batch_size)
|
170 |
+
input_path_2 = os.path.join(self.get_temp_dir(), 'train_2.tf_record')
|
171 |
+
_create_fake_dataset(
|
172 |
+
input_path_2,
|
173 |
+
seq_length=100,
|
174 |
+
num_masked_tokens=70,
|
175 |
+
max_seq_length=max_seq_length,
|
176 |
+
num_examples=batch_size)
|
177 |
+
input_paths = ','.join([input_path_1, input_path_2])
|
178 |
+
data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
|
179 |
+
is_training=False,
|
180 |
+
input_path=input_paths,
|
181 |
+
seq_bucket_lengths=[64, 128],
|
182 |
+
use_position_id=True,
|
183 |
+
global_batch_size=batch_size,
|
184 |
+
deterministic=True)
|
185 |
+
dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
|
186 |
+
data_config).load()
|
187 |
+
dataset_it = iter(dataset)
|
188 |
+
features = next(dataset_it)
|
189 |
+
self.assertCountEqual([
|
190 |
+
'input_word_ids',
|
191 |
+
'input_mask',
|
192 |
+
'input_type_ids',
|
193 |
+
'next_sentence_labels',
|
194 |
+
'masked_lm_positions',
|
195 |
+
'masked_lm_ids',
|
196 |
+
'masked_lm_weights',
|
197 |
+
'position_ids',
|
198 |
+
], features.keys())
|
199 |
+
# Sequence length dimension should be bucketized and pad to 64.
|
200 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, 64))
|
201 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, 64))
|
202 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, 64))
|
203 |
+
self.assertEqual(features['position_ids'].shape, (batch_size, 64))
|
204 |
+
self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 20))
|
205 |
+
features = next(dataset_it)
|
206 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, 128))
|
207 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, 128))
|
208 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, 128))
|
209 |
+
self.assertEqual(features['position_ids'].shape, (batch_size, 128))
|
210 |
+
self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 70))
|
211 |
+
|
212 |
+
def test_load_dataset_not_same_masks(self):
|
213 |
+
max_seq_length = 128
|
214 |
+
batch_size = 2
|
215 |
+
input_path_1 = os.path.join(self.get_temp_dir(), 'train_3.tf_record')
|
216 |
+
_create_fake_dataset(
|
217 |
+
input_path_1,
|
218 |
+
seq_length=60,
|
219 |
+
num_masked_tokens=20,
|
220 |
+
max_seq_length=max_seq_length,
|
221 |
+
num_examples=batch_size)
|
222 |
+
input_path_2 = os.path.join(self.get_temp_dir(), 'train_4.tf_record')
|
223 |
+
_create_fake_dataset(
|
224 |
+
input_path_2,
|
225 |
+
seq_length=60,
|
226 |
+
num_masked_tokens=15,
|
227 |
+
max_seq_length=max_seq_length,
|
228 |
+
num_examples=batch_size)
|
229 |
+
input_paths = ','.join([input_path_1, input_path_2])
|
230 |
+
data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
|
231 |
+
is_training=False,
|
232 |
+
input_path=input_paths,
|
233 |
+
seq_bucket_lengths=[64, 128],
|
234 |
+
use_position_id=True,
|
235 |
+
global_batch_size=batch_size * 2)
|
236 |
+
dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
|
237 |
+
data_config).load()
|
238 |
+
dataset_it = iter(dataset)
|
239 |
+
with self.assertRaisesRegex(
|
240 |
+
tf.errors.InvalidArgumentError, '.*Number of non padded mask tokens.*'):
|
241 |
+
next(dataset_it)
|
242 |
+
|
243 |
+
|
244 |
+
if __name__ == '__main__':
|
245 |
+
tf.test.main()
|
pretrain_text_dataloader.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Loads text dataset for the BERT pretraining task."""
|
16 |
+
import dataclasses
|
17 |
+
from typing import List, Mapping, Optional, Text
|
18 |
+
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
import tensorflow_text as tf_text
|
21 |
+
|
22 |
+
from official.common import dataset_fn
|
23 |
+
from official.core import config_definitions as cfg
|
24 |
+
from official.core import input_reader
|
25 |
+
from official.nlp.data import data_loader
|
26 |
+
from official.nlp.data import data_loader_factory
|
27 |
+
from official.nlp.modeling.ops import segment_extractor
|
28 |
+
|
29 |
+
|
30 |
+
@dataclasses.dataclass
|
31 |
+
class BertPretrainTextDataConfig(cfg.DataConfig):
|
32 |
+
"""Data config for BERT pretraining task (tasks/masked_lm) from text."""
|
33 |
+
input_path: str = ""
|
34 |
+
doc_batch_size: int = 8
|
35 |
+
global_batch_size: int = 512
|
36 |
+
is_training: bool = True
|
37 |
+
seq_length: int = 512
|
38 |
+
max_predictions_per_seq: int = 76
|
39 |
+
use_next_sentence_label: bool = True
|
40 |
+
# The name of the text feature fields. The text features will be
|
41 |
+
# concatenated in order.
|
42 |
+
# Note: More than 1 field name is not compatible with NSP.
|
43 |
+
text_field_names: Optional[List[str]] = dataclasses.field(
|
44 |
+
default_factory=lambda: ["text"])
|
45 |
+
vocab_file_path: str = ""
|
46 |
+
masking_rate: float = 0.15
|
47 |
+
use_whole_word_masking: bool = False
|
48 |
+
file_type: str = "tfrecord"
|
49 |
+
|
50 |
+
|
51 |
+
_CLS_TOKEN = b"[CLS]"
|
52 |
+
_SEP_TOKEN = b"[SEP]"
|
53 |
+
_MASK_TOKEN = b"[MASK]"
|
54 |
+
_NUM_OOV_BUCKETS = 1
|
55 |
+
# Accounts for [CLS] and 2 x [SEP] tokens
|
56 |
+
_NUM_SPECIAL_TOKENS = 3
|
57 |
+
|
58 |
+
|
59 |
+
@data_loader_factory.register_data_loader_cls(BertPretrainTextDataConfig)
|
60 |
+
class BertPretrainTextDataLoader(data_loader.DataLoader):
|
61 |
+
"""A class to load text dataset for BERT pretraining task."""
|
62 |
+
|
63 |
+
def __init__(self, params):
|
64 |
+
"""Inits `BertPretrainTextDataLoader` class.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
params: A `BertPretrainTextDataConfig` object.
|
68 |
+
"""
|
69 |
+
if len(params.text_field_names) > 1 and params.use_next_sentence_label:
|
70 |
+
raise ValueError("Currently there is no support for more than text field "
|
71 |
+
"while generating next sentence labels.")
|
72 |
+
|
73 |
+
self._params = params
|
74 |
+
self._seq_length = params.seq_length
|
75 |
+
self._max_predictions_per_seq = params.max_predictions_per_seq
|
76 |
+
self._use_next_sentence_label = params.use_next_sentence_label
|
77 |
+
self._masking_rate = params.masking_rate
|
78 |
+
self._use_whole_word_masking = params.use_whole_word_masking
|
79 |
+
|
80 |
+
lookup_table_init = tf.lookup.TextFileInitializer(
|
81 |
+
params.vocab_file_path,
|
82 |
+
key_dtype=tf.string,
|
83 |
+
key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
|
84 |
+
value_dtype=tf.int64,
|
85 |
+
value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
|
86 |
+
self._vocab_lookup_table = tf.lookup.StaticVocabularyTable(
|
87 |
+
lookup_table_init,
|
88 |
+
num_oov_buckets=_NUM_OOV_BUCKETS,
|
89 |
+
lookup_key_dtype=tf.string)
|
90 |
+
|
91 |
+
self._cls_token = self._vocab_lookup_table.lookup(tf.constant(_CLS_TOKEN))
|
92 |
+
self._sep_token = self._vocab_lookup_table.lookup(tf.constant(_SEP_TOKEN))
|
93 |
+
self._mask_token = self._vocab_lookup_table.lookup(tf.constant(_MASK_TOKEN))
|
94 |
+
|
95 |
+
# -_NUM_OOV_BUCKETS to offset unused OOV bucket.
|
96 |
+
self._vocab_size = self._vocab_lookup_table.size() - _NUM_OOV_BUCKETS
|
97 |
+
|
98 |
+
def _decode(self, record: tf.Tensor) -> Mapping[Text, tf.Tensor]:
|
99 |
+
"""Decodes a serialized tf.Example."""
|
100 |
+
name_to_features = {}
|
101 |
+
for text_field_name in self._params.text_field_names:
|
102 |
+
name_to_features[text_field_name] = tf.io.FixedLenFeature([], tf.string)
|
103 |
+
return tf.io.parse_single_example(record, name_to_features)
|
104 |
+
|
105 |
+
def _tokenize(self, segments):
|
106 |
+
"""Tokenize the input segments."""
|
107 |
+
# Tokenize segments
|
108 |
+
tokenizer = tf_text.BertTokenizer(
|
109 |
+
self._vocab_lookup_table, token_out_type=tf.int64)
|
110 |
+
|
111 |
+
if self._use_whole_word_masking:
|
112 |
+
# tokenize the segments which should have the shape:
|
113 |
+
# [num_sentence, (num_words), (num_wordpieces)]
|
114 |
+
segments = [tokenizer.tokenize(s) for s in segments]
|
115 |
+
else:
|
116 |
+
# tokenize the segments and merge out the token dimension so that each
|
117 |
+
# segment has the shape: [num_sentence, (num_wordpieces)]
|
118 |
+
segments = [tokenizer.tokenize(s).merge_dims(-2, -1) for s in segments]
|
119 |
+
|
120 |
+
# Truncate inputs
|
121 |
+
trimmer = tf_text.WaterfallTrimmer(
|
122 |
+
self._seq_length - _NUM_SPECIAL_TOKENS, axis=-1)
|
123 |
+
truncated_segments = trimmer.trim(segments)
|
124 |
+
|
125 |
+
# Combine segments, get segment ids and add special tokens
|
126 |
+
return tf_text.combine_segments(
|
127 |
+
truncated_segments,
|
128 |
+
start_of_sequence_id=self._cls_token,
|
129 |
+
end_of_segment_id=self._sep_token)
|
130 |
+
|
131 |
+
def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
|
132 |
+
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
|
133 |
+
if self._use_next_sentence_label:
|
134 |
+
input_text = record[self._params.text_field_names[0]]
|
135 |
+
# Split sentences
|
136 |
+
sentence_breaker = tf_text.RegexSplitter()
|
137 |
+
sentences = sentence_breaker.split(input_text)
|
138 |
+
|
139 |
+
# Extract next-sentence-prediction labels and segments
|
140 |
+
next_or_random_segment, is_next = (
|
141 |
+
segment_extractor.get_next_sentence_labels(sentences))
|
142 |
+
# merge dims to change shape from [num_docs, (num_segments)] to
|
143 |
+
# [total_num_segments]
|
144 |
+
is_next = is_next.merge_dims(-2, -1)
|
145 |
+
|
146 |
+
# construct segments with shape [(num_sentence)]
|
147 |
+
segments = [
|
148 |
+
sentences.merge_dims(-2, -1),
|
149 |
+
next_or_random_segment.merge_dims(-2, -1)
|
150 |
+
]
|
151 |
+
else:
|
152 |
+
segments = [record[name] for name in self._params.text_field_names]
|
153 |
+
|
154 |
+
segments_combined, segment_ids = self._tokenize(segments)
|
155 |
+
|
156 |
+
# Dynamic masking
|
157 |
+
item_selector = tf_text.RandomItemSelector(
|
158 |
+
self._max_predictions_per_seq,
|
159 |
+
selection_rate=self._masking_rate,
|
160 |
+
unselectable_ids=[self._cls_token, self._sep_token],
|
161 |
+
shuffle_fn=(tf.identity if self._params.deterministic else None))
|
162 |
+
values_chooser = tf_text.MaskValuesChooser(
|
163 |
+
vocab_size=self._vocab_size, mask_token=self._mask_token)
|
164 |
+
masked_input_ids, masked_lm_positions, masked_lm_ids = (
|
165 |
+
tf_text.mask_language_model(
|
166 |
+
segments_combined,
|
167 |
+
item_selector=item_selector,
|
168 |
+
mask_values_chooser=values_chooser,
|
169 |
+
))
|
170 |
+
|
171 |
+
# Pad out to fixed shape and get input mask.
|
172 |
+
seq_lengths = {
|
173 |
+
"input_word_ids": self._seq_length,
|
174 |
+
"input_type_ids": self._seq_length,
|
175 |
+
"masked_lm_positions": self._max_predictions_per_seq,
|
176 |
+
"masked_lm_ids": self._max_predictions_per_seq,
|
177 |
+
}
|
178 |
+
model_inputs = {
|
179 |
+
"input_word_ids": masked_input_ids,
|
180 |
+
"input_type_ids": segment_ids,
|
181 |
+
"masked_lm_positions": masked_lm_positions,
|
182 |
+
"masked_lm_ids": masked_lm_ids,
|
183 |
+
}
|
184 |
+
padded_inputs_and_mask = tf.nest.map_structure(tf_text.pad_model_inputs,
|
185 |
+
model_inputs, seq_lengths)
|
186 |
+
model_inputs = {
|
187 |
+
k: padded_inputs_and_mask[k][0] for k in padded_inputs_and_mask
|
188 |
+
}
|
189 |
+
model_inputs["masked_lm_weights"] = tf.cast(
|
190 |
+
padded_inputs_and_mask["masked_lm_ids"][1], tf.float32)
|
191 |
+
model_inputs["input_mask"] = padded_inputs_and_mask["input_word_ids"][1]
|
192 |
+
|
193 |
+
if self._use_next_sentence_label:
|
194 |
+
model_inputs["next_sentence_labels"] = is_next
|
195 |
+
|
196 |
+
for name in model_inputs:
|
197 |
+
t = model_inputs[name]
|
198 |
+
if t.dtype == tf.int64:
|
199 |
+
t = tf.cast(t, tf.int32)
|
200 |
+
model_inputs[name] = t
|
201 |
+
|
202 |
+
return model_inputs
|
203 |
+
|
204 |
+
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
205 |
+
"""Returns a tf.dataset.Dataset."""
|
206 |
+
|
207 |
+
def _batch_docs(dataset, input_context):
|
208 |
+
per_core_doc_batch_size = (
|
209 |
+
input_context.get_per_replica_batch_size(self._params.doc_batch_size)
|
210 |
+
if input_context else self._params.doc_batch_size)
|
211 |
+
return dataset.batch(per_core_doc_batch_size)
|
212 |
+
|
213 |
+
reader = input_reader.InputReader(
|
214 |
+
params=self._params,
|
215 |
+
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
|
216 |
+
decoder_fn=self._decode if self._params.input_path else None,
|
217 |
+
transform_and_batch_fn=_batch_docs
|
218 |
+
if self._use_next_sentence_label else None,
|
219 |
+
postprocess_fn=self._bert_preprocess)
|
220 |
+
transformed_inputs = reader.read(input_context)
|
221 |
+
per_core_example_batch_size = (
|
222 |
+
input_context.get_per_replica_batch_size(self._params.global_batch_size)
|
223 |
+
if input_context else self._params.global_batch_size)
|
224 |
+
batched_inputs = transformed_inputs.unbatch().batch(
|
225 |
+
per_core_example_batch_size, self._params.drop_remainder)
|
226 |
+
return batched_inputs.prefetch(tf.data.experimental.AUTOTUNE)
|
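A minimal sketch of the text-based pretraining loader above, assuming TFRecords that carry a raw "text" string feature and a WordPiece vocab file; both paths are hypothetical.

from official.nlp.data import pretrain_text_dataloader

data_config = pretrain_text_dataloader.BertPretrainTextDataConfig(
    input_path='/tmp/wiki_text.tf_record',  # hypothetical raw-text TFRecord
    vocab_file_path='/tmp/vocab.txt',       # hypothetical WordPiece vocab
    seq_length=128,
    max_predictions_per_seq=20,
    global_batch_size=32)
dataset = pretrain_text_dataloader.BertPretrainTextDataLoader(data_config).load()
# Sentence splitting, next-sentence pairing, tokenization and dynamic masking
# all run inside the tf.data pipeline, so each batch already contains
# input_word_ids, masked_lm_* features and next_sentence_labels.
features = next(iter(dataset))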
question_answering_dataloader.py
ADDED
@@ -0,0 +1,115 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Loads dataset for the question answering (e.g, SQuAD) task."""
|
16 |
+
import dataclasses
|
17 |
+
from typing import Mapping, Optional
|
18 |
+
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
from official.common import dataset_fn
|
21 |
+
from official.core import config_definitions as cfg
|
22 |
+
from official.core import input_reader
|
23 |
+
from official.nlp.data import data_loader
|
24 |
+
from official.nlp.data import data_loader_factory
|
25 |
+
|
26 |
+
|
27 |
+
@dataclasses.dataclass
|
28 |
+
class QADataConfig(cfg.DataConfig):
|
29 |
+
"""Data config for question answering task (tasks/question_answering)."""
|
30 |
+
# For training, `input_path` is expected to be a pre-processed TFRecord file,
|
31 |
+
# while for evaluation, it is expected to be a raw JSON file (b/173814590).
|
32 |
+
input_path: str = ''
|
33 |
+
global_batch_size: int = 48
|
34 |
+
is_training: bool = True
|
35 |
+
seq_length: int = 384
|
36 |
+
# Settings below are question answering specific.
|
37 |
+
version_2_with_negative: bool = False
|
38 |
+
# Settings below are only used for eval mode.
|
39 |
+
input_preprocessed_data_path: str = ''
|
40 |
+
doc_stride: int = 128
|
41 |
+
query_length: int = 64
|
42 |
+
# The path to the vocab file of word piece tokenizer or the
|
43 |
+
# model of the sentence piece tokenizer.
|
44 |
+
vocab_file: str = ''
|
45 |
+
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
|
46 |
+
do_lower_case: bool = True
|
47 |
+
xlnet_format: bool = False
|
48 |
+
file_type: str = 'tfrecord'
|
49 |
+
|
50 |
+
|
51 |
+
@data_loader_factory.register_data_loader_cls(QADataConfig)
|
52 |
+
class QuestionAnsweringDataLoader(data_loader.DataLoader):
|
53 |
+
"""A class to load dataset for sentence prediction (classification) task."""
|
54 |
+
|
55 |
+
def __init__(self, params):
|
56 |
+
self._params = params
|
57 |
+
self._seq_length = params.seq_length
|
58 |
+
self._is_training = params.is_training
|
59 |
+
self._xlnet_format = params.xlnet_format
|
60 |
+
|
61 |
+
def _decode(self, record: tf.Tensor):
|
62 |
+
"""Decodes a serialized tf.Example."""
|
63 |
+
name_to_features = {
|
64 |
+
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
65 |
+
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
66 |
+
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
67 |
+
}
|
68 |
+
if self._xlnet_format:
|
69 |
+
name_to_features['class_index'] = tf.io.FixedLenFeature([], tf.int64)
|
70 |
+
name_to_features['paragraph_mask'] = tf.io.FixedLenFeature(
|
71 |
+
[self._seq_length], tf.int64)
|
72 |
+
if self._is_training:
|
73 |
+
name_to_features['is_impossible'] = tf.io.FixedLenFeature([], tf.int64)
|
74 |
+
|
75 |
+
if self._is_training:
|
76 |
+
name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
|
77 |
+
name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
|
78 |
+
else:
|
79 |
+
name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
|
80 |
+
example = tf.io.parse_single_example(record, name_to_features)
|
81 |
+
|
82 |
+
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
83 |
+
# So cast all int64 to int32.
|
84 |
+
for name in example:
|
85 |
+
t = example[name]
|
86 |
+
if t.dtype == tf.int64:
|
87 |
+
t = tf.cast(t, tf.int32)
|
88 |
+
example[name] = t
|
89 |
+
|
90 |
+
return example
|
91 |
+
|
92 |
+
def _parse(self, record: Mapping[str, tf.Tensor]):
|
93 |
+
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
|
94 |
+
x, y = {}, {}
|
95 |
+
for name, tensor in record.items():
|
96 |
+
if name in ('start_positions', 'end_positions', 'is_impossible'):
|
97 |
+
y[name] = tensor
|
98 |
+
elif name == 'input_ids':
|
99 |
+
x['input_word_ids'] = tensor
|
100 |
+
elif name == 'segment_ids':
|
101 |
+
x['input_type_ids'] = tensor
|
102 |
+
else:
|
103 |
+
x[name] = tensor
|
104 |
+
if name == 'start_positions' and self._xlnet_format:
|
105 |
+
x[name] = tensor
|
106 |
+
return (x, y)
|
107 |
+
|
108 |
+
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
109 |
+
"""Returns a tf.dataset.Dataset."""
|
110 |
+
reader = input_reader.InputReader(
|
111 |
+
params=self._params,
|
112 |
+
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
|
113 |
+
decoder_fn=self._decode,
|
114 |
+
parser_fn=self._parse)
|
115 |
+
return reader.read(input_context)
|
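For completeness, a minimal sketch of the SQuAD-style loader above, mirroring the test that follows; the TFRecord path is an illustrative assumption.

from official.nlp.data import question_answering_dataloader

data_config = question_answering_dataloader.QADataConfig(
    input_path='/tmp/squad_train.tf_record',  # hypothetical pre-processed TFRecord
    is_training=True,
    seq_length=384,
    global_batch_size=8)
dataset = question_answering_dataloader.QuestionAnsweringDataLoader(
    data_config).load()
# _parse returns a (features, labels) tuple: input_word_ids/input_mask/
# input_type_ids go into the features dict, start/end positions into labels.
features, labels = next(iter(dataset))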
question_answering_dataloader_test.py
ADDED
@@ -0,0 +1,74 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for official.nlp.data.question_answering_dataloader."""
|
16 |
+
import os
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
|
21 |
+
from official.nlp.data import question_answering_dataloader
|
22 |
+
|
23 |
+
|
24 |
+
def _create_fake_dataset(output_path, seq_length):
|
25 |
+
"""Creates a fake dataset."""
|
26 |
+
writer = tf.io.TFRecordWriter(output_path)
|
27 |
+
|
28 |
+
def create_int_feature(values):
|
29 |
+
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
30 |
+
return f
|
31 |
+
|
32 |
+
for _ in range(100):
|
33 |
+
features = {}
|
34 |
+
input_ids = np.random.randint(100, size=(seq_length))
|
35 |
+
features['input_ids'] = create_int_feature(input_ids)
|
36 |
+
features['input_mask'] = create_int_feature(np.ones_like(input_ids))
|
37 |
+
features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
|
38 |
+
features['start_positions'] = create_int_feature(np.array([0]))
|
39 |
+
features['end_positions'] = create_int_feature(np.array([10]))
|
40 |
+
|
41 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
42 |
+
writer.write(tf_example.SerializeToString())
|
43 |
+
writer.close()
|
44 |
+
|
45 |
+
|
46 |
+
class QuestionAnsweringDataTest(tf.test.TestCase):
|
47 |
+
|
48 |
+
def test_load_dataset(self):
|
49 |
+
seq_length = 128
|
50 |
+
batch_size = 10
|
51 |
+
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
52 |
+
_create_fake_dataset(input_path, seq_length)
|
53 |
+
data_config = question_answering_dataloader.QADataConfig(
|
54 |
+
is_training=True,
|
55 |
+
input_path=input_path,
|
56 |
+
seq_length=seq_length,
|
57 |
+
global_batch_size=batch_size)
|
58 |
+
dataset = question_answering_dataloader.QuestionAnsweringDataLoader(
|
59 |
+
data_config).load()
|
60 |
+
features, labels = next(iter(dataset))
|
61 |
+
|
62 |
+
self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
|
63 |
+
features.keys())
|
64 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
65 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
66 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
67 |
+
|
68 |
+
self.assertCountEqual(['start_positions', 'end_positions'], labels.keys())
|
69 |
+
self.assertEqual(labels['start_positions'].shape, (batch_size,))
|
70 |
+
self.assertEqual(labels['end_positions'].shape, (batch_size,))
|
71 |
+
|
72 |
+
|
73 |
+
if __name__ == '__main__':
|
74 |
+
tf.test.main()
|
sentence_prediction_dataloader.py
ADDED
@@ -0,0 +1,267 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Loads dataset for the sentence prediction (classification) task."""
|
16 |
+
import dataclasses
|
17 |
+
import functools
|
18 |
+
from typing import List, Mapping, Optional, Tuple
|
19 |
+
|
20 |
+
import tensorflow as tf, tf_keras
|
21 |
+
import tensorflow_hub as hub
|
22 |
+
|
23 |
+
from official.common import dataset_fn
|
24 |
+
from official.core import config_definitions as cfg
|
25 |
+
from official.core import input_reader
|
26 |
+
from official.nlp import modeling
|
27 |
+
from official.nlp.data import data_loader
|
28 |
+
from official.nlp.data import data_loader_factory
|
29 |
+
|
30 |
+
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
|
31 |
+
|
32 |
+
|
33 |
+
@dataclasses.dataclass
|
34 |
+
class SentencePredictionDataConfig(cfg.DataConfig):
|
35 |
+
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
|
36 |
+
input_path: str = ''
|
37 |
+
global_batch_size: int = 32
|
38 |
+
is_training: bool = True
|
39 |
+
seq_length: int = 128
|
40 |
+
label_type: str = 'int'
|
41 |
+
# Whether to include the example id number.
|
42 |
+
include_example_id: bool = False
|
43 |
+
label_field: str = 'label_ids'
|
44 |
+
# Maps the key in TfExample to feature name.
|
45 |
+
# E.g. 'label_ids' to 'next_sentence_labels'.
|
46 |
+
label_name: Optional[Tuple[str, str]] = None
|
47 |
+
# Either tfrecord, sstable, or recordio.
|
48 |
+
file_type: str = 'tfrecord'
|
49 |
+
|
50 |
+
|
51 |
+
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
|
52 |
+
class SentencePredictionDataLoader(data_loader.DataLoader):
|
53 |
+
"""A class to load dataset for sentence prediction (classification) task."""
|
54 |
+
|
55 |
+
def __init__(self, params):
|
56 |
+
self._params = params
|
57 |
+
self._seq_length = params.seq_length
|
58 |
+
self._include_example_id = params.include_example_id
|
59 |
+
self._label_field = params.label_field
|
60 |
+
if params.label_name:
|
61 |
+
self._label_name_mapping = dict([params.label_name])
|
62 |
+
else:
|
63 |
+
self._label_name_mapping = dict()
|
64 |
+
|
65 |
+
def name_to_features_spec(self):
|
66 |
+
"""Defines features to decode. Subclass may override to append features."""
|
67 |
+
label_type = LABEL_TYPES_MAP[self._params.label_type]
|
68 |
+
name_to_features = {
|
69 |
+
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
70 |
+
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
71 |
+
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
72 |
+
self._label_field: tf.io.FixedLenFeature([], label_type),
|
73 |
+
}
|
74 |
+
if self._include_example_id:
|
75 |
+
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
|
76 |
+
|
77 |
+
return name_to_features
|
78 |
+
|
79 |
+
def _decode(self, record: tf.Tensor):
|
80 |
+
"""Decodes a serialized tf.Example."""
|
81 |
+
example = tf.io.parse_single_example(record, self.name_to_features_spec())
|
82 |
+
|
83 |
+
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
84 |
+
# So cast all int64 to int32.
|
85 |
+
for name in example:
|
86 |
+
t = example[name]
|
87 |
+
if t.dtype == tf.int64:
|
88 |
+
t = tf.cast(t, tf.int32)
|
89 |
+
example[name] = t
|
90 |
+
|
91 |
+
return example
|
92 |
+
|
93 |
+
def _parse(self, record: Mapping[str, tf.Tensor]):
|
94 |
+
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
|
95 |
+
key_mapping = {
|
96 |
+
'input_ids': 'input_word_ids',
|
97 |
+
'input_mask': 'input_mask',
|
98 |
+
'segment_ids': 'input_type_ids'
|
99 |
+
}
|
100 |
+
ret = {}
|
101 |
+
for record_key in record:
|
102 |
+
if record_key in key_mapping:
|
103 |
+
ret[key_mapping[record_key]] = record[record_key]
|
104 |
+
else:
|
105 |
+
ret[record_key] = record[record_key]
|
106 |
+
|
107 |
+
if self._label_field in self._label_name_mapping:
|
108 |
+
ret[self._label_name_mapping[self._label_field]] = record[
|
109 |
+
self._label_field]
|
110 |
+
|
111 |
+
return ret
|
112 |
+
|
113 |
+
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
114 |
+
"""Returns a tf.dataset.Dataset."""
|
115 |
+
reader = input_reader.InputReader(
|
116 |
+
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
|
117 |
+
params=self._params,
|
118 |
+
decoder_fn=self._decode,
|
119 |
+
parser_fn=self._parse)
|
120 |
+
return reader.read(input_context)
|
121 |
+
|
122 |
+
|
123 |
+
@dataclasses.dataclass
|
124 |
+
class SentencePredictionTextDataConfig(cfg.DataConfig):
|
125 |
+
"""Data config for sentence prediction task with raw text."""
|
126 |
+
# Either set `input_path`...
|
127 |
+
input_path: str = ''
|
128 |
+
# Either `int` or `float`.
|
129 |
+
label_type: str = 'int'
|
130 |
+
# ...or `tfds_name` and `tfds_split` to specify input.
|
131 |
+
tfds_name: str = ''
|
132 |
+
tfds_split: str = ''
|
133 |
+
# The name of the text feature fields. The text features will be
|
134 |
+
# concatenated in order.
|
135 |
+
text_fields: Optional[List[str]] = None
|
136 |
+
label_field: str = 'label'
|
137 |
+
global_batch_size: int = 32
|
138 |
+
seq_length: int = 128
|
139 |
+
is_training: bool = True
|
140 |
+
# Either build preprocessing with Python code by specifying these values
|
141 |
+
# for modeling.layers.BertTokenizer()/SentencepieceTokenizer()....
|
142 |
+
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
|
143 |
+
# Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
|
144 |
+
# file if tokenization is SentencePiece.
|
145 |
+
vocab_file: str = ''
|
146 |
+
lower_case: bool = True
|
147 |
+
# ...or load preprocessing from a SavedModel at this location.
|
148 |
+
preprocessing_hub_module_url: str = ''
|
149 |
+
# Either tfrecord or sstsable or recordio.
|
150 |
+
file_type: str = 'tfrecord'
|
151 |
+
include_example_id: bool = False
|
152 |
+
|
153 |
+
|
154 |
+
class TextProcessor(tf.Module):
|
155 |
+
"""Text features processing for sentence prediction task."""
|
156 |
+
|
157 |
+
def __init__(self,
|
158 |
+
seq_length: int,
|
159 |
+
vocab_file: Optional[str] = None,
|
160 |
+
tokenization: Optional[str] = None,
|
161 |
+
lower_case: Optional[bool] = True,
|
162 |
+
preprocessing_hub_module_url: Optional[str] = None):
|
163 |
+
if preprocessing_hub_module_url:
|
164 |
+
self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
|
165 |
+
self._tokenizer = self._preprocessing_hub_module.tokenize
|
166 |
+
self._pack_inputs = functools.partial(
|
167 |
+
self._preprocessing_hub_module.bert_pack_inputs,
|
168 |
+
seq_length=seq_length)
|
169 |
+
return
|
170 |
+
|
171 |
+
if tokenization == 'WordPiece':
|
172 |
+
self._tokenizer = modeling.layers.BertTokenizer(
|
173 |
+
vocab_file=vocab_file, lower_case=lower_case)
|
174 |
+
elif tokenization == 'SentencePiece':
|
175 |
+
self._tokenizer = modeling.layers.SentencepieceTokenizer(
|
176 |
+
model_file_path=vocab_file,
|
177 |
+
lower_case=lower_case,
|
178 |
+
strip_diacritics=True) # Strip diacritics to follow ALBERT model
|
179 |
+
else:
|
180 |
+
raise ValueError('Unsupported tokenization: %s' % tokenization)
|
181 |
+
|
182 |
+
self._pack_inputs = modeling.layers.BertPackInputs(
|
183 |
+
seq_length=seq_length,
|
184 |
+
special_tokens_dict=self._tokenizer.get_special_tokens_dict())
|
185 |
+
|
186 |
+
def __call__(self, segments):
|
187 |
+
segments = [self._tokenizer(s) for s in segments]
|
188 |
+
# BertTokenizer returns a RaggedTensor with shape [batch, word, subword],
|
189 |
+
# and SentencepieceTokenizer returns a RaggedTensor with shape
|
190 |
+
# [batch, sentencepiece],
|
191 |
+
segments = [
|
192 |
+
tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
|
193 |
+
for x in segments
|
194 |
+
]
|
195 |
+
return self._pack_inputs(segments)
|
196 |
+
|
197 |
+
|
198 |
+
@data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
|
199 |
+
class SentencePredictionTextDataLoader(data_loader.DataLoader):
|
200 |
+
"""Loads dataset with raw text for sentence prediction task."""
|
201 |
+
|
202 |
+
def __init__(self, params):
|
203 |
+
if bool(params.tfds_name) != bool(params.tfds_split):
|
204 |
+
raise ValueError('`tfds_name` and `tfds_split` should be specified or '
|
205 |
+
'unspecified at the same time.')
|
206 |
+
if bool(params.tfds_name) == bool(params.input_path):
|
207 |
+
raise ValueError('Must specify either `tfds_name` and `tfds_split` '
|
208 |
+
'or `input_path`.')
|
209 |
+
if not params.text_fields:
|
210 |
+
raise ValueError('Unexpected empty text fields.')
|
211 |
+
if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
|
212 |
+
raise ValueError('Must specify exactly one of vocab_file (with matching '
|
213 |
+
'lower_case flag) or preprocessing_hub_module_url.')
|
214 |
+
|
215 |
+
self._params = params
|
216 |
+
self._text_fields = params.text_fields
|
217 |
+
self._label_field = params.label_field
|
218 |
+
self._label_type = params.label_type
|
219 |
+
self._include_example_id = params.include_example_id
|
220 |
+
self._text_processor = TextProcessor(
|
221 |
+
seq_length=params.seq_length,
|
222 |
+
vocab_file=params.vocab_file,
|
223 |
+
tokenization=params.tokenization,
|
224 |
+
lower_case=params.lower_case,
|
225 |
+
preprocessing_hub_module_url=params.preprocessing_hub_module_url)
|
226 |
+
|
227 |
+
def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
|
228 |
+
"""Berts preprocess."""
|
229 |
+
segments = [record[x] for x in self._text_fields]
|
230 |
+
model_inputs = self._text_processor(segments)
|
231 |
+
for key in record:
|
232 |
+
if key not in self._text_fields:
|
233 |
+
model_inputs[key] = record[key]
|
234 |
+
return model_inputs
|
235 |
+
|
236 |
+
def name_to_features_spec(self):
|
237 |
+
name_to_features = {}
|
238 |
+
for text_field in self._text_fields:
|
239 |
+
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
|
240 |
+
|
241 |
+
label_type = LABEL_TYPES_MAP[self._label_type]
|
242 |
+
name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
|
243 |
+
if self._include_example_id:
|
244 |
+
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
|
245 |
+
return name_to_features
|
246 |
+
|
247 |
+
def _decode(self, record: tf.Tensor):
|
248 |
+
"""Decodes a serialized tf.Example."""
|
249 |
+
example = tf.io.parse_single_example(record, self.name_to_features_spec())
|
250 |
+
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
251 |
+
# So cast all int64 to int32.
|
252 |
+
for name in example:
|
253 |
+
t = example[name]
|
254 |
+
if t.dtype == tf.int64:
|
255 |
+
t = tf.cast(t, tf.int32)
|
256 |
+
example[name] = t
|
257 |
+
|
258 |
+
return example
|
259 |
+
|
260 |
+
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
261 |
+
"""Returns a tf.dataset.Dataset."""
|
262 |
+
reader = input_reader.InputReader(
|
263 |
+
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
|
264 |
+
decoder_fn=self._decode if self._params.input_path else None,
|
265 |
+
params=self._params,
|
266 |
+
postprocess_fn=self._bert_preprocess)
|
267 |
+
return reader.read(input_context)
|
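For reference, this text dataloader can be exercised much like the WordPiece path in the tests below, only against real data. A minimal sketch, assuming an illustrative TFRecord of raw text and an illustrative vocab file (the paths and field names are not part of this commit):

    # Hypothetical usage of SentencePredictionTextDataLoader; the paths and
    # field names below are assumptions for illustration only.
    from official.nlp.data import sentence_prediction_dataloader as loader

    data_config = loader.SentencePredictionTextDataConfig(
        input_path='/tmp/raw_text.tf_record',    # TFRecord with raw string fields
        text_fields=['sentence1', 'sentence2'],  # string features to tokenize
        vocab_file='/tmp/vocab.txt',             # WordPiece vocabulary
        tokenization='WordPiece',
        lower_case=True,
        seq_length=128,
        global_batch_size=32,
        is_training=True)
    dataset = loader.SentencePredictionTextDataLoader(data_config).load()
    features = next(iter(dataset))
    # features contains input_word_ids, input_mask, input_type_ids and the label.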
sentence_prediction_dataloader_test.py
ADDED
@@ -0,0 +1,290 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for official.nlp.data.sentence_prediction_dataloader."""
|
16 |
+
import os
|
17 |
+
|
18 |
+
from absl.testing import parameterized
|
19 |
+
import numpy as np
|
20 |
+
import tensorflow as tf, tf_keras
|
21 |
+
|
22 |
+
from sentencepiece import SentencePieceTrainer
|
23 |
+
from official.nlp.data import sentence_prediction_dataloader as loader
|
24 |
+
|
25 |
+
|
26 |
+
def _create_fake_preprocessed_dataset(output_path, seq_length, label_type):
|
27 |
+
"""Creates a fake dataset."""
|
28 |
+
writer = tf.io.TFRecordWriter(output_path)
|
29 |
+
|
30 |
+
def create_int_feature(values):
|
31 |
+
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
32 |
+
return f
|
33 |
+
|
34 |
+
def create_float_feature(values):
|
35 |
+
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
36 |
+
return f
|
37 |
+
|
38 |
+
for _ in range(100):
|
39 |
+
features = {}
|
40 |
+
input_ids = np.random.randint(100, size=(seq_length))
|
41 |
+
features['input_ids'] = create_int_feature(input_ids)
|
42 |
+
features['input_mask'] = create_int_feature(np.ones_like(input_ids))
|
43 |
+
features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
|
44 |
+
|
45 |
+
if label_type == 'int':
|
46 |
+
features['label_ids'] = create_int_feature([1])
|
47 |
+
elif label_type == 'float':
|
48 |
+
features['label_ids'] = create_float_feature([0.5])
|
49 |
+
else:
|
50 |
+
raise ValueError('Unsupported label_type: %s' % label_type)
|
51 |
+
|
52 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
53 |
+
writer.write(tf_example.SerializeToString())
|
54 |
+
writer.close()
|
55 |
+
|
56 |
+
|
57 |
+
def _create_fake_raw_dataset(output_path, text_fields, label_type):
|
58 |
+
"""Creates a fake tf record file."""
|
59 |
+
writer = tf.io.TFRecordWriter(output_path)
|
60 |
+
|
61 |
+
def create_str_feature(value):
|
62 |
+
f = tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
|
63 |
+
return f
|
64 |
+
|
65 |
+
def create_int_feature(values):
|
66 |
+
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
67 |
+
return f
|
68 |
+
|
69 |
+
def create_float_feature(values):
|
70 |
+
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
71 |
+
return f
|
72 |
+
|
73 |
+
for _ in range(100):
|
74 |
+
features = {}
|
75 |
+
for text_field in text_fields:
|
76 |
+
features[text_field] = create_str_feature([b'hello world'])
|
77 |
+
|
78 |
+
if label_type == 'int':
|
79 |
+
features['label'] = create_int_feature([0])
|
80 |
+
elif label_type == 'float':
|
81 |
+
features['label'] = create_float_feature([0.5])
|
82 |
+
else:
|
83 |
+
raise ValueError('Unexpected label_type: %s' % label_type)
|
84 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
85 |
+
writer.write(tf_example.SerializeToString())
|
86 |
+
writer.close()
|
87 |
+
|
88 |
+
|
89 |
+
def _create_fake_sentencepiece_model(output_dir):
|
90 |
+
vocab = ['a', 'b', 'c', 'd', 'e', 'abc', 'def', 'ABC', 'DEF']
|
91 |
+
model_prefix = os.path.join(output_dir, 'spm_model')
|
92 |
+
input_text_file_path = os.path.join(output_dir, 'train_input.txt')
|
93 |
+
with tf.io.gfile.GFile(input_text_file_path, 'w') as f:
|
94 |
+
f.write(' '.join(vocab + ['\n']))
|
95 |
+
# Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
|
96 |
+
full_vocab_size = len(vocab) + 7
|
97 |
+
flags = dict(
|
98 |
+
model_prefix=model_prefix,
|
99 |
+
model_type='word',
|
100 |
+
input=input_text_file_path,
|
101 |
+
pad_id=0,
|
102 |
+
unk_id=1,
|
103 |
+
control_symbols='[CLS],[SEP],[MASK]',
|
104 |
+
vocab_size=full_vocab_size,
|
105 |
+
bos_id=full_vocab_size - 2,
|
106 |
+
eos_id=full_vocab_size - 1)
|
107 |
+
SentencePieceTrainer.Train(' '.join(
|
108 |
+
['--{}={}'.format(k, v) for k, v in flags.items()]))
|
109 |
+
return model_prefix + '.model'
|
110 |
+
|
111 |
+
|
112 |
+
def _create_fake_vocab_file(vocab_file_path):
|
113 |
+
tokens = ['[PAD]']
|
114 |
+
for i in range(1, 100):
|
115 |
+
tokens.append('[unused%d]' % i)
|
116 |
+
tokens.extend(['[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world'])
|
117 |
+
with tf.io.gfile.GFile(vocab_file_path, 'w') as outfile:
|
118 |
+
outfile.write('\n'.join(tokens))
|
119 |
+
|
120 |
+
|
121 |
+
class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
|
122 |
+
|
123 |
+
@parameterized.parameters(('int', tf.int32), ('float', tf.float32))
|
124 |
+
def test_load_dataset(self, label_type, expected_label_type):
|
125 |
+
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
126 |
+
batch_size = 10
|
127 |
+
seq_length = 128
|
128 |
+
_create_fake_preprocessed_dataset(input_path, seq_length, label_type)
|
129 |
+
data_config = loader.SentencePredictionDataConfig(
|
130 |
+
input_path=input_path,
|
131 |
+
seq_length=seq_length,
|
132 |
+
global_batch_size=batch_size,
|
133 |
+
label_type=label_type)
|
134 |
+
dataset = loader.SentencePredictionDataLoader(data_config).load()
|
135 |
+
features = next(iter(dataset))
|
136 |
+
self.assertCountEqual(
|
137 |
+
['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
|
138 |
+
features.keys())
|
139 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
140 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
141 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
142 |
+
self.assertEqual(features['label_ids'].shape, (batch_size,))
|
143 |
+
self.assertEqual(features['label_ids'].dtype, expected_label_type)
|
144 |
+
|
145 |
+
def test_load_dataset_with_label_mapping(self):
|
146 |
+
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
147 |
+
batch_size = 10
|
148 |
+
seq_length = 128
|
149 |
+
_create_fake_preprocessed_dataset(input_path, seq_length, 'int')
|
150 |
+
data_config = loader.SentencePredictionDataConfig(
|
151 |
+
input_path=input_path,
|
152 |
+
seq_length=seq_length,
|
153 |
+
global_batch_size=batch_size,
|
154 |
+
label_type='int',
|
155 |
+
label_name=('label_ids', 'next_sentence_labels'))
|
156 |
+
dataset = loader.SentencePredictionDataLoader(data_config).load()
|
157 |
+
features = next(iter(dataset))
|
158 |
+
self.assertCountEqual([
|
159 |
+
'input_word_ids', 'input_mask', 'input_type_ids',
|
160 |
+
'next_sentence_labels', 'label_ids'
|
161 |
+
], features.keys())
|
162 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
163 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
164 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
165 |
+
self.assertEqual(features['label_ids'].shape, (batch_size,))
|
166 |
+
self.assertEqual(features['label_ids'].dtype, tf.int32)
|
167 |
+
self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
|
168 |
+
self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
|
169 |
+
|
170 |
+
|
171 |
+
class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
|
172 |
+
parameterized.TestCase):
|
173 |
+
|
174 |
+
@parameterized.parameters(True, False)
|
175 |
+
def test_python_wordpiece_preprocessing(self, use_tfds):
|
176 |
+
batch_size = 10
|
177 |
+
seq_length = 256 # Non-default value.
|
178 |
+
lower_case = True
|
179 |
+
|
180 |
+
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
181 |
+
text_fields = ['sentence1', 'sentence2']
|
182 |
+
if not use_tfds:
|
183 |
+
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
|
184 |
+
|
185 |
+
vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
|
186 |
+
_create_fake_vocab_file(vocab_file_path)
|
187 |
+
|
188 |
+
data_config = loader.SentencePredictionTextDataConfig(
|
189 |
+
input_path='' if use_tfds else tf_record_path,
|
190 |
+
tfds_name='glue/mrpc' if use_tfds else '',
|
191 |
+
tfds_split='train' if use_tfds else '',
|
192 |
+
text_fields=text_fields,
|
193 |
+
global_batch_size=batch_size,
|
194 |
+
seq_length=seq_length,
|
195 |
+
is_training=True,
|
196 |
+
lower_case=lower_case,
|
197 |
+
vocab_file=vocab_file_path)
|
198 |
+
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
|
199 |
+
features = next(iter(dataset))
|
200 |
+
label_field = data_config.label_field
|
201 |
+
expected_keys = [
|
202 |
+
'input_word_ids', 'input_type_ids', 'input_mask', label_field
|
203 |
+
]
|
204 |
+
if use_tfds:
|
205 |
+
expected_keys += ['idx']
|
206 |
+
self.assertCountEqual(expected_keys, features.keys())
|
207 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
208 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
209 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
210 |
+
self.assertEqual(features[label_field].shape, (batch_size,))
|
211 |
+
|
212 |
+
@parameterized.parameters(True, False)
|
213 |
+
def test_python_sentencepiece_preprocessing(self, use_tfds):
|
214 |
+
batch_size = 10
|
215 |
+
seq_length = 256 # Non-default value.
|
216 |
+
lower_case = True
|
217 |
+
|
218 |
+
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
219 |
+
text_fields = ['sentence1', 'sentence2']
|
220 |
+
if not use_tfds:
|
221 |
+
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
|
222 |
+
|
223 |
+
sp_model_file_path = _create_fake_sentencepiece_model(self.get_temp_dir())
|
224 |
+
data_config = loader.SentencePredictionTextDataConfig(
|
225 |
+
input_path='' if use_tfds else tf_record_path,
|
226 |
+
tfds_name='glue/mrpc' if use_tfds else '',
|
227 |
+
tfds_split='train' if use_tfds else '',
|
228 |
+
text_fields=text_fields,
|
229 |
+
global_batch_size=batch_size,
|
230 |
+
seq_length=seq_length,
|
231 |
+
is_training=True,
|
232 |
+
lower_case=lower_case,
|
233 |
+
tokenization='SentencePiece',
|
234 |
+
vocab_file=sp_model_file_path,
|
235 |
+
)
|
236 |
+
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
|
237 |
+
features = next(iter(dataset))
|
238 |
+
label_field = data_config.label_field
|
239 |
+
expected_keys = [
|
240 |
+
'input_word_ids', 'input_type_ids', 'input_mask', label_field
|
241 |
+
]
|
242 |
+
if use_tfds:
|
243 |
+
expected_keys += ['idx']
|
244 |
+
self.assertCountEqual(expected_keys, features.keys())
|
245 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
246 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
247 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
248 |
+
self.assertEqual(features[label_field].shape, (batch_size,))
|
249 |
+
|
250 |
+
@parameterized.parameters(True, False)
|
251 |
+
def test_saved_model_preprocessing(self, use_tfds):
|
252 |
+
batch_size = 10
|
253 |
+
seq_length = 256 # Non-default value.
|
254 |
+
|
255 |
+
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
|
256 |
+
text_fields = ['sentence1', 'sentence2']
|
257 |
+
if not use_tfds:
|
258 |
+
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='float')
|
259 |
+
|
260 |
+
vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
|
261 |
+
_create_fake_vocab_file(vocab_file_path)
|
262 |
+
data_config = loader.SentencePredictionTextDataConfig(
|
263 |
+
input_path='' if use_tfds else tf_record_path,
|
264 |
+
tfds_name='glue/mrpc' if use_tfds else '',
|
265 |
+
tfds_split='train' if use_tfds else '',
|
266 |
+
text_fields=text_fields,
|
267 |
+
global_batch_size=batch_size,
|
268 |
+
seq_length=seq_length,
|
269 |
+
is_training=True,
|
270 |
+
preprocessing_hub_module_url=(
|
271 |
+
'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
|
272 |
+
label_type='int' if use_tfds else 'float',
|
273 |
+
)
|
274 |
+
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
|
275 |
+
features = next(iter(dataset))
|
276 |
+
label_field = data_config.label_field
|
277 |
+
expected_keys = [
|
278 |
+
'input_word_ids', 'input_type_ids', 'input_mask', label_field
|
279 |
+
]
|
280 |
+
if use_tfds:
|
281 |
+
expected_keys += ['idx']
|
282 |
+
self.assertCountEqual(expected_keys, features.keys())
|
283 |
+
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
|
284 |
+
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
|
285 |
+
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
|
286 |
+
self.assertEqual(features[label_field].shape, (batch_size,))
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == '__main__':
|
290 |
+
tf.test.main()
|
sentence_retrieval_lib.py
ADDED
@@ -0,0 +1,166 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""BERT library to process data for cross lingual sentence retrieval task."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
|
19 |
+
from absl import logging
|
20 |
+
from official.nlp.data import classifier_data_lib
|
21 |
+
from official.nlp.tools import tokenization
|
22 |
+
|
23 |
+
|
24 |
+
class BuccProcessor(classifier_data_lib.DataProcessor):
|
25 |
+
"""Procssor for Xtreme BUCC data set."""
|
26 |
+
supported_languages = ["de", "fr", "ru", "zh"]
|
27 |
+
|
28 |
+
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
|
29 |
+
super(BuccProcessor, self).__init__(process_text_fn)
|
30 |
+
self.languages = BuccProcessor.supported_languages
|
31 |
+
|
32 |
+
def get_dev_examples(self, data_dir, file_pattern):
|
33 |
+
return self._create_examples(
|
34 |
+
self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
|
35 |
+
"sample")
|
36 |
+
|
37 |
+
def get_test_examples(self, data_dir, file_pattern):
|
38 |
+
return self._create_examples(
|
39 |
+
self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
|
40 |
+
"test")
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def get_processor_name():
|
44 |
+
"""See base class."""
|
45 |
+
return "BUCC"
|
46 |
+
|
47 |
+
def _create_examples(self, lines, set_type):
|
48 |
+
"""Creates examples for the training and dev sets."""
|
49 |
+
examples = []
|
50 |
+
for (i, line) in enumerate(lines):
|
51 |
+
guid = "%s-%s" % (set_type, i)
|
52 |
+
example_id = int(line[0].split("-")[1])
|
53 |
+
text_a = self.process_text_fn(line[1])
|
54 |
+
examples.append(
|
55 |
+
classifier_data_lib.InputExample(
|
56 |
+
guid=guid, text_a=text_a, example_id=example_id))
|
57 |
+
return examples
|
58 |
+
|
59 |
+
|
60 |
+
class TatoebaProcessor(classifier_data_lib.DataProcessor):
|
61 |
+
"""Procssor for Xtreme Tatoeba data set."""
|
62 |
+
supported_languages = [
|
63 |
+
"af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
|
64 |
+
"he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
|
65 |
+
"nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
|
66 |
+
]
|
67 |
+
|
68 |
+
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
|
69 |
+
super(TatoebaProcessor, self).__init__(process_text_fn)
|
70 |
+
self.languages = TatoebaProcessor.supported_languages
|
71 |
+
|
72 |
+
def get_test_examples(self, data_dir, file_path):
|
73 |
+
return self._create_examples(
|
74 |
+
self._read_tsv(os.path.join(data_dir, file_path)), "test")
|
75 |
+
|
76 |
+
@staticmethod
|
77 |
+
def get_processor_name():
|
78 |
+
"""See base class."""
|
79 |
+
return "TATOEBA"
|
80 |
+
|
81 |
+
def _create_examples(self, lines, set_type):
|
82 |
+
"""Creates examples for the training and dev sets."""
|
83 |
+
examples = []
|
84 |
+
for (i, line) in enumerate(lines):
|
85 |
+
guid = "%s-%s" % (set_type, i)
|
86 |
+
text_a = self.process_text_fn(line[0])
|
87 |
+
examples.append(
|
88 |
+
classifier_data_lib.InputExample(
|
89 |
+
guid=guid, text_a=text_a, example_id=i))
|
90 |
+
return examples
|
91 |
+
|
92 |
+
|
93 |
+
def generate_sentence_retrevial_tf_record(processor,
|
94 |
+
data_dir,
|
95 |
+
tokenizer,
|
96 |
+
eval_data_output_path=None,
|
97 |
+
test_data_output_path=None,
|
98 |
+
max_seq_length=128):
|
99 |
+
"""Generates the tf records for retrieval tasks.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
processor: Input processor object to be used for generating data. Subclass
|
103 |
+
of `DataProcessor`.
|
104 |
+
data_dir: Directory that contains train/eval data to process. Data files
|
105 |
+
should be in the expected format.
|
106 |
+
tokenizer: The tokenizer to be applied on the data.
|
107 |
+
eval_data_output_path: Output to which processed tf record for evaluation
|
108 |
+
will be saved.
|
109 |
+
test_data_output_path: Output to which processed tf record for testing
|
110 |
+
will be saved. Must be a pattern template with {} if processor has
|
111 |
+
language specific test data.
|
112 |
+
max_seq_length: Maximum sequence length of the to be generated
|
113 |
+
training/eval data.
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
A dictionary containing input meta data.
|
117 |
+
"""
|
118 |
+
assert eval_data_output_path or test_data_output_path
|
119 |
+
|
120 |
+
if processor.get_processor_name() == "BUCC":
|
121 |
+
path_pattern = "{}-en.{{}}.{}"
|
122 |
+
|
123 |
+
if processor.get_processor_name() == "TATOEBA":
|
124 |
+
path_pattern = "{}-en.{}"
|
125 |
+
|
126 |
+
meta_data = {
|
127 |
+
"processor_type": processor.get_processor_name(),
|
128 |
+
"max_seq_length": max_seq_length,
|
129 |
+
"number_eval_data": {},
|
130 |
+
"number_test_data": {},
|
131 |
+
}
|
132 |
+
logging.info("Start to process %s task data", processor.get_processor_name())
|
133 |
+
|
134 |
+
for lang_a in processor.languages:
|
135 |
+
for lang_b in [lang_a, "en"]:
|
136 |
+
if eval_data_output_path:
|
137 |
+
eval_input_data_examples = processor.get_dev_examples(
|
138 |
+
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
|
139 |
+
|
140 |
+
num_eval_data = len(eval_input_data_examples)
|
141 |
+
logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
|
142 |
+
lang_a, lang_b)
|
143 |
+
output_file = os.path.join(
|
144 |
+
eval_data_output_path,
|
145 |
+
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
|
146 |
+
classifier_data_lib.file_based_convert_examples_to_features(
|
147 |
+
eval_input_data_examples, None, max_seq_length, tokenizer,
|
148 |
+
output_file, None)
|
149 |
+
meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
|
150 |
+
|
151 |
+
if test_data_output_path:
|
152 |
+
test_input_data_examples = processor.get_test_examples(
|
153 |
+
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
|
154 |
+
|
155 |
+
num_test_data = len(test_input_data_examples)
|
156 |
+
logging.info("Processing %d test examples of %s-en.%s", num_test_data,
|
157 |
+
lang_a, lang_b)
|
158 |
+
output_file = os.path.join(
|
159 |
+
test_data_output_path,
|
160 |
+
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
|
161 |
+
classifier_data_lib.file_based_convert_examples_to_features(
|
162 |
+
test_input_data_examples, None, max_seq_length, tokenizer,
|
163 |
+
output_file, None)
|
164 |
+
meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
|
165 |
+
|
166 |
+
return meta_data
|
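The two processors and the generator above are typically combined as follows; a minimal sketch, assuming illustrative data/output directories and a WordPiece FullTokenizer from official.nlp.tools.tokenization:

    # Hypothetical invocation; directories and the vocab path are assumptions.
    from official.nlp.tools import tokenization

    processor = BuccProcessor()
    tokenizer = tokenization.FullTokenizer(
        vocab_file='/tmp/vocab.txt', do_lower_case=True)
    meta_data = generate_sentence_retrevial_tf_record(
        processor=processor,
        data_dir='/tmp/bucc',                      # extracted BUCC data
        tokenizer=tokenizer,
        eval_data_output_path='/tmp/bucc_tfrecords',
        max_seq_length=128)
    # meta_data['number_eval_data'] maps '{lang_a}-en.{lang_b}' keys to counts.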
squad_lib.py
ADDED
@@ -0,0 +1,975 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
|
16 |
+
# pylint: disable=g-bad-import-order
|
17 |
+
import collections
|
18 |
+
import copy
|
19 |
+
import json
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
|
23 |
+
import six
|
24 |
+
|
25 |
+
from absl import logging
|
26 |
+
import tensorflow as tf, tf_keras
|
27 |
+
|
28 |
+
from official.nlp.tools import tokenization
|
29 |
+
|
30 |
+
|
31 |
+
class SquadExample(object):
|
32 |
+
"""A single training/test example for simple sequence classification.
|
33 |
+
|
34 |
+
For examples without an answer, the start and end position are -1.
|
35 |
+
|
36 |
+
Attributes:
|
37 |
+
qas_id: ID of the question-answer pair.
|
38 |
+
question_text: Original text for the question.
|
39 |
+
doc_tokens: The list of tokens in the context obtained by splitting on
|
40 |
+
whitespace only.
|
41 |
+
orig_answer_text: Original text for the answer.
|
42 |
+
start_position: Starting index of the answer in `doc_tokens`.
|
43 |
+
end_position: Ending index of the answer in `doc_tokens`.
|
44 |
+
is_impossible: Whether the question is impossible to answer given the
|
45 |
+
context. Only used in SQuAD 2.0.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(self,
|
49 |
+
qas_id,
|
50 |
+
question_text,
|
51 |
+
doc_tokens,
|
52 |
+
orig_answer_text=None,
|
53 |
+
start_position=None,
|
54 |
+
end_position=None,
|
55 |
+
is_impossible=False):
|
56 |
+
self.qas_id = qas_id
|
57 |
+
self.question_text = question_text
|
58 |
+
self.doc_tokens = doc_tokens
|
59 |
+
self.orig_answer_text = orig_answer_text
|
60 |
+
self.start_position = start_position
|
61 |
+
self.end_position = end_position
|
62 |
+
self.is_impossible = is_impossible
|
63 |
+
|
64 |
+
def __str__(self):
|
65 |
+
return self.__repr__()
|
66 |
+
|
67 |
+
def __repr__(self):
|
68 |
+
s = ""
|
69 |
+
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
|
70 |
+
s += ", question_text: %s" % (
|
71 |
+
tokenization.printable_text(self.question_text))
|
72 |
+
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
|
73 |
+
if self.start_position:
|
74 |
+
s += ", start_position: %d" % (self.start_position)
|
75 |
+
if self.start_position:
|
76 |
+
s += ", end_position: %d" % (self.end_position)
|
77 |
+
if self.start_position:
|
78 |
+
s += ", is_impossible: %r" % (self.is_impossible)
|
79 |
+
return s
|
80 |
+
|
81 |
+
|
82 |
+
class InputFeatures(object):
|
83 |
+
"""A single set of features of data."""
|
84 |
+
|
85 |
+
def __init__(self,
|
86 |
+
unique_id,
|
87 |
+
example_index,
|
88 |
+
doc_span_index,
|
89 |
+
tokens,
|
90 |
+
token_to_orig_map,
|
91 |
+
token_is_max_context,
|
92 |
+
input_ids,
|
93 |
+
input_mask,
|
94 |
+
segment_ids,
|
95 |
+
paragraph_mask=None,
|
96 |
+
class_index=None,
|
97 |
+
start_position=None,
|
98 |
+
end_position=None,
|
99 |
+
is_impossible=None):
|
100 |
+
self.unique_id = unique_id
|
101 |
+
self.example_index = example_index
|
102 |
+
self.doc_span_index = doc_span_index
|
103 |
+
self.tokens = tokens
|
104 |
+
self.token_to_orig_map = token_to_orig_map
|
105 |
+
self.token_is_max_context = token_is_max_context
|
106 |
+
self.input_ids = input_ids
|
107 |
+
self.input_mask = input_mask
|
108 |
+
self.segment_ids = segment_ids
|
109 |
+
self.start_position = start_position
|
110 |
+
self.end_position = end_position
|
111 |
+
self.is_impossible = is_impossible
|
112 |
+
self.paragraph_mask = paragraph_mask
|
113 |
+
self.class_index = class_index
|
114 |
+
|
115 |
+
|
116 |
+
class FeatureWriter(object):
|
117 |
+
"""Writes InputFeature to TF example file."""
|
118 |
+
|
119 |
+
def __init__(self, filename, is_training):
|
120 |
+
self.filename = filename
|
121 |
+
self.is_training = is_training
|
122 |
+
self.num_features = 0
|
123 |
+
tf.io.gfile.makedirs(os.path.dirname(filename))
|
124 |
+
self._writer = tf.io.TFRecordWriter(filename)
|
125 |
+
|
126 |
+
def process_feature(self, feature):
|
127 |
+
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
|
128 |
+
self.num_features += 1
|
129 |
+
|
130 |
+
def create_int_feature(values):
|
131 |
+
feature = tf.train.Feature(
|
132 |
+
int64_list=tf.train.Int64List(value=list(values)))
|
133 |
+
return feature
|
134 |
+
|
135 |
+
features = collections.OrderedDict()
|
136 |
+
features["unique_ids"] = create_int_feature([feature.unique_id])
|
137 |
+
features["input_ids"] = create_int_feature(feature.input_ids)
|
138 |
+
features["input_mask"] = create_int_feature(feature.input_mask)
|
139 |
+
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
140 |
+
|
141 |
+
if feature.paragraph_mask is not None:
|
142 |
+
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
|
143 |
+
if feature.class_index is not None:
|
144 |
+
features["class_index"] = create_int_feature([feature.class_index])
|
145 |
+
|
146 |
+
if self.is_training:
|
147 |
+
features["start_positions"] = create_int_feature([feature.start_position])
|
148 |
+
features["end_positions"] = create_int_feature([feature.end_position])
|
149 |
+
impossible = 0
|
150 |
+
if feature.is_impossible:
|
151 |
+
impossible = 1
|
152 |
+
features["is_impossible"] = create_int_feature([impossible])
|
153 |
+
|
154 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
155 |
+
self._writer.write(tf_example.SerializeToString())
|
156 |
+
|
157 |
+
def close(self):
|
158 |
+
self._writer.close()
|
159 |
+
|
160 |
+
|
161 |
+
def read_squad_examples(input_file, is_training,
|
162 |
+
version_2_with_negative,
|
163 |
+
translated_input_folder=None):
|
164 |
+
"""Read a SQuAD json file into a list of SquadExample."""
|
165 |
+
with tf.io.gfile.GFile(input_file, "r") as reader:
|
166 |
+
input_data = json.load(reader)["data"]
|
167 |
+
|
168 |
+
if translated_input_folder is not None:
|
169 |
+
translated_files = tf.io.gfile.glob(
|
170 |
+
os.path.join(translated_input_folder, "*.json"))
|
171 |
+
for file in translated_files:
|
172 |
+
with tf.io.gfile.GFile(file, "r") as reader:
|
173 |
+
input_data.extend(json.load(reader)["data"])
|
174 |
+
|
175 |
+
def is_whitespace(c):
|
176 |
+
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
177 |
+
return True
|
178 |
+
return False
|
179 |
+
|
180 |
+
examples = []
|
181 |
+
for entry in input_data:
|
182 |
+
for paragraph in entry["paragraphs"]:
|
183 |
+
paragraph_text = paragraph["context"]
|
184 |
+
doc_tokens = []
|
185 |
+
char_to_word_offset = []
|
186 |
+
prev_is_whitespace = True
|
187 |
+
for c in paragraph_text:
|
188 |
+
if is_whitespace(c):
|
189 |
+
prev_is_whitespace = True
|
190 |
+
else:
|
191 |
+
if prev_is_whitespace:
|
192 |
+
doc_tokens.append(c)
|
193 |
+
else:
|
194 |
+
doc_tokens[-1] += c
|
195 |
+
prev_is_whitespace = False
|
196 |
+
char_to_word_offset.append(len(doc_tokens) - 1)
|
197 |
+
|
198 |
+
for qa in paragraph["qas"]:
|
199 |
+
qas_id = qa["id"]
|
200 |
+
question_text = qa["question"]
|
201 |
+
start_position = None
|
202 |
+
end_position = None
|
203 |
+
orig_answer_text = None
|
204 |
+
is_impossible = False
|
205 |
+
if is_training:
|
206 |
+
|
207 |
+
if version_2_with_negative:
|
208 |
+
is_impossible = qa["is_impossible"]
|
209 |
+
if (len(qa["answers"]) != 1) and (not is_impossible):
|
210 |
+
raise ValueError(
|
211 |
+
"For training, each question should have exactly 1 answer.")
|
212 |
+
if not is_impossible:
|
213 |
+
answer = qa["answers"][0]
|
214 |
+
orig_answer_text = answer["text"]
|
215 |
+
answer_offset = answer["answer_start"]
|
216 |
+
answer_length = len(orig_answer_text)
|
217 |
+
start_position = char_to_word_offset[answer_offset]
|
218 |
+
end_position = char_to_word_offset[answer_offset + answer_length -
|
219 |
+
1]
|
220 |
+
# Only add answers where the text can be exactly recovered from the
|
221 |
+
# document. If this CAN'T happen it's likely due to weird Unicode
|
222 |
+
# stuff so we will just skip the example.
|
223 |
+
#
|
224 |
+
# Note that this means for training mode, every example is NOT
|
225 |
+
# guaranteed to be preserved.
|
226 |
+
actual_text = " ".join(doc_tokens[start_position:(end_position +
|
227 |
+
1)])
|
228 |
+
cleaned_answer_text = " ".join(
|
229 |
+
tokenization.whitespace_tokenize(orig_answer_text))
|
230 |
+
if actual_text.find(cleaned_answer_text) == -1:
|
231 |
+
logging.warning("Could not find answer: '%s' vs. '%s'",
|
232 |
+
actual_text, cleaned_answer_text)
|
233 |
+
continue
|
234 |
+
else:
|
235 |
+
start_position = -1
|
236 |
+
end_position = -1
|
237 |
+
orig_answer_text = ""
|
238 |
+
|
239 |
+
example = SquadExample(
|
240 |
+
qas_id=qas_id,
|
241 |
+
question_text=question_text,
|
242 |
+
doc_tokens=doc_tokens,
|
243 |
+
orig_answer_text=orig_answer_text,
|
244 |
+
start_position=start_position,
|
245 |
+
end_position=end_position,
|
246 |
+
is_impossible=is_impossible)
|
247 |
+
examples.append(example)
|
248 |
+
|
249 |
+
return examples
|
250 |
+
|
251 |
+
|
252 |
+
def convert_examples_to_features(examples,
|
253 |
+
tokenizer,
|
254 |
+
max_seq_length,
|
255 |
+
doc_stride,
|
256 |
+
max_query_length,
|
257 |
+
is_training,
|
258 |
+
output_fn,
|
259 |
+
xlnet_format=False,
|
260 |
+
batch_size=None):
|
261 |
+
"""Loads a data file into a list of `InputBatch`s."""
|
262 |
+
|
263 |
+
base_id = 1000000000
|
264 |
+
unique_id = base_id
|
265 |
+
feature = None
|
266 |
+
for (example_index, example) in enumerate(examples):
|
267 |
+
query_tokens = tokenizer.tokenize(example.question_text)
|
268 |
+
|
269 |
+
if len(query_tokens) > max_query_length:
|
270 |
+
query_tokens = query_tokens[0:max_query_length]
|
271 |
+
|
272 |
+
tok_to_orig_index = []
|
273 |
+
orig_to_tok_index = []
|
274 |
+
all_doc_tokens = []
|
275 |
+
for (i, token) in enumerate(example.doc_tokens):
|
276 |
+
orig_to_tok_index.append(len(all_doc_tokens))
|
277 |
+
sub_tokens = tokenizer.tokenize(token)
|
278 |
+
for sub_token in sub_tokens:
|
279 |
+
tok_to_orig_index.append(i)
|
280 |
+
all_doc_tokens.append(sub_token)
|
281 |
+
|
282 |
+
tok_start_position = None
|
283 |
+
tok_end_position = None
|
284 |
+
if is_training and example.is_impossible:
|
285 |
+
tok_start_position = -1
|
286 |
+
tok_end_position = -1
|
287 |
+
if is_training and not example.is_impossible:
|
288 |
+
tok_start_position = orig_to_tok_index[example.start_position]
|
289 |
+
if example.end_position < len(example.doc_tokens) - 1:
|
290 |
+
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
291 |
+
else:
|
292 |
+
tok_end_position = len(all_doc_tokens) - 1
|
293 |
+
(tok_start_position, tok_end_position) = _improve_answer_span(
|
294 |
+
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
|
295 |
+
example.orig_answer_text)
|
296 |
+
|
297 |
+
# The -3 accounts for [CLS], [SEP] and [SEP]
|
298 |
+
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
299 |
+
|
300 |
+
# We can have documents that are longer than the maximum sequence length.
|
301 |
+
# To deal with this we do a sliding window approach, where we take chunks
|
302 |
+
# of up to our max length with a stride of `doc_stride`.
|
303 |
+
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
304 |
+
"DocSpan", ["start", "length"])
|
305 |
+
doc_spans = []
|
306 |
+
start_offset = 0
|
307 |
+
while start_offset < len(all_doc_tokens):
|
308 |
+
length = len(all_doc_tokens) - start_offset
|
309 |
+
if length > max_tokens_for_doc:
|
310 |
+
length = max_tokens_for_doc
|
311 |
+
doc_spans.append(_DocSpan(start=start_offset, length=length))
|
312 |
+
if start_offset + length == len(all_doc_tokens):
|
313 |
+
break
|
314 |
+
start_offset += min(length, doc_stride)
|
315 |
+
|
316 |
+
for (doc_span_index, doc_span) in enumerate(doc_spans):
|
317 |
+
tokens = []
|
318 |
+
token_to_orig_map = {}
|
319 |
+
token_is_max_context = {}
|
320 |
+
segment_ids = []
|
321 |
+
|
322 |
+
# Paragraph mask used in XLNet.
|
323 |
+
# 1 represents paragraph and class tokens.
|
324 |
+
# 0 represents query and other special tokens.
|
325 |
+
paragraph_mask = []
|
326 |
+
|
327 |
+
# pylint: disable=cell-var-from-loop
|
328 |
+
def process_query(seg_q):
|
329 |
+
for token in query_tokens:
|
330 |
+
tokens.append(token)
|
331 |
+
segment_ids.append(seg_q)
|
332 |
+
paragraph_mask.append(0)
|
333 |
+
tokens.append("[SEP]")
|
334 |
+
segment_ids.append(seg_q)
|
335 |
+
paragraph_mask.append(0)
|
336 |
+
|
337 |
+
def process_paragraph(seg_p):
|
338 |
+
for i in range(doc_span.length):
|
339 |
+
split_token_index = doc_span.start + i
|
340 |
+
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
341 |
+
|
342 |
+
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
343 |
+
split_token_index)
|
344 |
+
token_is_max_context[len(tokens)] = is_max_context
|
345 |
+
tokens.append(all_doc_tokens[split_token_index])
|
346 |
+
segment_ids.append(seg_p)
|
347 |
+
paragraph_mask.append(1)
|
348 |
+
tokens.append("[SEP]")
|
349 |
+
segment_ids.append(seg_p)
|
350 |
+
paragraph_mask.append(0)
|
351 |
+
|
352 |
+
def process_class(seg_class):
|
353 |
+
class_index = len(segment_ids)
|
354 |
+
tokens.append("[CLS]")
|
355 |
+
segment_ids.append(seg_class)
|
356 |
+
paragraph_mask.append(1)
|
357 |
+
return class_index
|
358 |
+
|
359 |
+
if xlnet_format:
|
360 |
+
seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
|
361 |
+
process_paragraph(seg_p)
|
362 |
+
process_query(seg_q)
|
363 |
+
class_index = process_class(seg_class)
|
364 |
+
else:
|
365 |
+
seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
|
366 |
+
class_index = process_class(seg_class)
|
367 |
+
process_query(seg_q)
|
368 |
+
process_paragraph(seg_p)
|
369 |
+
|
370 |
+
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
371 |
+
|
372 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
373 |
+
# tokens are attended to.
|
374 |
+
input_mask = [1] * len(input_ids)
|
375 |
+
|
376 |
+
# Zero-pad up to the sequence length.
|
377 |
+
while len(input_ids) < max_seq_length:
|
378 |
+
input_ids.append(0)
|
379 |
+
input_mask.append(0)
|
380 |
+
segment_ids.append(seg_pad)
|
381 |
+
paragraph_mask.append(0)
|
382 |
+
|
383 |
+
assert len(input_ids) == max_seq_length
|
384 |
+
assert len(input_mask) == max_seq_length
|
385 |
+
assert len(segment_ids) == max_seq_length
|
386 |
+
assert len(paragraph_mask) == max_seq_length
|
387 |
+
|
388 |
+
start_position = 0
|
389 |
+
end_position = 0
|
390 |
+
span_contains_answer = False
|
391 |
+
|
392 |
+
if is_training and not example.is_impossible:
|
393 |
+
# For training, if our document chunk does not contain an annotation
|
394 |
+
# we throw it out, since there is nothing to predict.
|
395 |
+
doc_start = doc_span.start
|
396 |
+
doc_end = doc_span.start + doc_span.length - 1
|
397 |
+
span_contains_answer = (tok_start_position >= doc_start and
|
398 |
+
tok_end_position <= doc_end)
|
399 |
+
if span_contains_answer:
|
400 |
+
doc_offset = 0 if xlnet_format else len(query_tokens) + 2
|
401 |
+
start_position = tok_start_position - doc_start + doc_offset
|
402 |
+
end_position = tok_end_position - doc_start + doc_offset
|
403 |
+
|
404 |
+
if example_index < 20:
|
405 |
+
logging.info("*** Example ***")
|
406 |
+
logging.info("unique_id: %s", (unique_id))
|
407 |
+
logging.info("example_index: %s", (example_index))
|
408 |
+
logging.info("doc_span_index: %s", (doc_span_index))
|
409 |
+
logging.info("tokens: %s",
|
410 |
+
" ".join([tokenization.printable_text(x) for x in tokens]))
|
411 |
+
logging.info(
|
412 |
+
"token_to_orig_map: %s", " ".join([
|
413 |
+
"%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)
|
414 |
+
]))
|
415 |
+
logging.info(
|
416 |
+
"token_is_max_context: %s", " ".join([
|
417 |
+
"%d:%s" % (x, y)
|
418 |
+
for (x, y) in six.iteritems(token_is_max_context)
|
419 |
+
]))
|
420 |
+
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
|
421 |
+
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
|
422 |
+
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
|
423 |
+
logging.info("paragraph_mask: %s", " ".join(
|
424 |
+
[str(x) for x in paragraph_mask]))
|
425 |
+
logging.info("class_index: %d", class_index)
|
426 |
+
if is_training:
|
427 |
+
if span_contains_answer:
|
428 |
+
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
429 |
+
logging.info("start_position: %d", (start_position))
|
430 |
+
logging.info("end_position: %d", (end_position))
|
431 |
+
logging.info("answer: %s", tokenization.printable_text(answer_text))
|
432 |
+
else:
|
433 |
+
logging.info("document span doesn't contain answer")
|
434 |
+
|
435 |
+
feature = InputFeatures(
|
436 |
+
unique_id=unique_id,
|
437 |
+
example_index=example_index,
|
438 |
+
doc_span_index=doc_span_index,
|
439 |
+
tokens=tokens,
|
440 |
+
paragraph_mask=paragraph_mask,
|
441 |
+
class_index=class_index,
|
442 |
+
token_to_orig_map=token_to_orig_map,
|
443 |
+
token_is_max_context=token_is_max_context,
|
444 |
+
input_ids=input_ids,
|
445 |
+
input_mask=input_mask,
|
446 |
+
segment_ids=segment_ids,
|
447 |
+
start_position=start_position,
|
448 |
+
end_position=end_position,
|
449 |
+
is_impossible=not span_contains_answer)
|
450 |
+
|
451 |
+
# Run callback
|
452 |
+
if is_training:
|
453 |
+
output_fn(feature)
|
454 |
+
else:
|
455 |
+
output_fn(feature, is_padding=False)
|
456 |
+
|
457 |
+
unique_id += 1
|
458 |
+
|
459 |
+
if not is_training and feature:
|
460 |
+
assert batch_size
|
461 |
+
num_padding = 0
|
462 |
+
num_examples = unique_id - base_id
|
463 |
+
if unique_id % batch_size != 0:
|
464 |
+
num_padding = batch_size - (num_examples % batch_size)
|
465 |
+
logging.info("Adding padding examples to make sure no partial batch.")
|
466 |
+
logging.info("Adds %d padding examples for inference.", num_padding)
|
467 |
+
dummy_feature = copy.deepcopy(feature)
|
468 |
+
for _ in range(num_padding):
|
469 |
+
dummy_feature.unique_id = unique_id
|
470 |
+
|
471 |
+
# Run callback
|
472 |
+
output_fn(feature, is_padding=True)
|
473 |
+
unique_id += 1
|
474 |
+
return unique_id - base_id
|
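The pieces defined so far are typically chained: read_squad_examples parses the raw json, convert_examples_to_features tokenizes and windows each example, and FeatureWriter serializes the results. A minimal sketch, assuming illustrative file paths and a WordPiece FullTokenizer from official.nlp.tools.tokenization:

    # Hypothetical end-to-end preprocessing run; the paths are assumptions.
    from official.nlp.tools import tokenization

    examples = read_squad_examples(
        input_file='/tmp/train-v1.1.json',       # SQuAD 1.1 training json
        is_training=True,
        version_2_with_negative=False)
    tokenizer = tokenization.FullTokenizer(
        vocab_file='/tmp/vocab.txt', do_lower_case=True)
    writer = FeatureWriter(filename='/tmp/train.tf_record', is_training=True)
    num_features = convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=True,
        output_fn=writer.process_feature)
    writer.close()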
475 |
+
|
476 |
+
|
477 |
+
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
|
478 |
+
orig_answer_text):
|
479 |
+
"""Returns tokenized answer spans that better match the annotated answer."""
|
480 |
+
|
481 |
+
# The SQuAD annotations are character based. We first project them to
|
482 |
+
# whitespace-tokenized words. But then after WordPiece tokenization, we can
|
483 |
+
# often find a "better match". For example:
|
484 |
+
#
|
485 |
+
# Question: What year was John Smith born?
|
486 |
+
# Context: The leader was John Smith (1895-1943).
|
487 |
+
# Answer: 1895
|
488 |
+
#
|
489 |
+
# The original whitespace-tokenized answer will be "(1895-1943).". However
|
490 |
+
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
|
491 |
+
# the exact answer, 1895.
|
492 |
+
#
|
493 |
+
# However, this is not always possible. Consider the following:
|
494 |
+
#
|
495 |
+
# Question: What country is the top exporter of electronics?
|
496 |
+
# Context: The Japanese electronics industry is the largest in the world.
|
497 |
+
# Answer: Japan
|
498 |
+
#
|
499 |
+
# In this case, the annotator chose "Japan" as a character sub-span of
|
500 |
+
# the word "Japanese". Since our WordPiece tokenizer does not split
|
501 |
+
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
|
502 |
+
# in SQuAD, but does happen.
|
503 |
+
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
|
504 |
+
|
505 |
+
for new_start in range(input_start, input_end + 1):
|
506 |
+
for new_end in range(input_end, new_start - 1, -1):
|
507 |
+
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
|
508 |
+
if text_span == tok_answer_text:
|
509 |
+
return (new_start, new_end)
|
510 |
+
|
511 |
+
return (input_start, input_end)
|
512 |
+
|
513 |
+
|
514 |
+
def _check_is_max_context(doc_spans, cur_span_index, position):
|
515 |
+
"""Check if this is the 'max context' doc span for the token."""
|
516 |
+
|
517 |
+
# Because of the sliding window approach taken to scoring documents, a single
|
518 |
+
# token can appear in multiple documents. E.g.
|
519 |
+
# Doc: the man went to the store and bought a gallon of milk
|
520 |
+
# Span A: the man went to the
|
521 |
+
# Span B: to the store and bought
|
522 |
+
# Span C: and bought a gallon of
|
523 |
+
# ...
|
524 |
+
#
|
525 |
+
# Now the word 'bought' will have two scores from spans B and C. We only
|
526 |
+
# want to consider the score with "maximum context", which we define as
|
527 |
+
# the *minimum* of its left and right context (the *sum* of left and
|
528 |
+
# right context will always be the same, of course).
|
529 |
+
#
|
530 |
+
# In the example the maximum context for 'bought' would be span C since
|
531 |
+
# it has 1 left context and 3 right context, while span B has 4 left context
|
532 |
+
# and 0 right context.
|
533 |
+
best_score = None
|
534 |
+
best_span_index = None
|
535 |
+
for (span_index, doc_span) in enumerate(doc_spans):
|
536 |
+
end = doc_span.start + doc_span.length - 1
|
537 |
+
if position < doc_span.start:
|
538 |
+
continue
|
539 |
+
if position > end:
|
540 |
+
continue
|
541 |
+
num_left_context = position - doc_span.start
|
542 |
+
num_right_context = end - position
|
543 |
+
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
544 |
+
if best_score is None or score > best_score:
|
545 |
+
best_score = score
|
546 |
+
best_span_index = span_index
|
547 |
+
|
548 |
+
return cur_span_index == best_span_index
|
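To make the scoring concrete, here is a small worked check of the example in the comment above (the spans are hand-built to match it; 'bought' sits at position 7):

    # Spans A, B, C from the comment, as _DocSpan-like tuples.
    import collections
    DocSpan = collections.namedtuple('DocSpan', ['start', 'length'])
    spans = [DocSpan(0, 5),   # the man went to the
             DocSpan(3, 5),   # to the store and bought
             DocSpan(6, 5)]   # and bought a gallon of
    # Span B: 4 left / 0 right context; span C: 1 left / 3 right context.
    assert not _check_is_max_context(spans, 1, 7)  # span B is not max context
    assert _check_is_max_context(spans, 2, 7)      # span C is max context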
549 |
+
|
550 |
+
|
551 |
+
def write_predictions(all_examples,
|
552 |
+
all_features,
|
553 |
+
all_results,
|
554 |
+
n_best_size,
|
555 |
+
max_answer_length,
|
556 |
+
do_lower_case,
|
557 |
+
output_prediction_file,
|
558 |
+
output_nbest_file,
|
559 |
+
output_null_log_odds_file,
|
560 |
+
version_2_with_negative=False,
|
561 |
+
null_score_diff_threshold=0.0,
|
562 |
+
verbose=False):
|
563 |
+
"""Write final predictions to the json file and log-odds of null if needed."""
|
564 |
+
logging.info("Writing predictions to: %s", (output_prediction_file))
|
565 |
+
logging.info("Writing nbest to: %s", (output_nbest_file))
|
566 |
+
|
567 |
+
all_predictions, all_nbest_json, scores_diff_json = (
|
568 |
+
postprocess_output(
|
569 |
+
all_examples=all_examples,
|
570 |
+
all_features=all_features,
|
571 |
+
all_results=all_results,
|
572 |
+
n_best_size=n_best_size,
|
573 |
+
max_answer_length=max_answer_length,
|
574 |
+
do_lower_case=do_lower_case,
|
575 |
+
version_2_with_negative=version_2_with_negative,
|
576 |
+
null_score_diff_threshold=null_score_diff_threshold,
|
577 |
+
verbose=verbose))
|
578 |
+
|
579 |
+
write_to_json_files(all_predictions, output_prediction_file)
|
580 |
+
write_to_json_files(all_nbest_json, output_nbest_file)
|
581 |
+
if version_2_with_negative:
|
582 |
+
write_to_json_files(scores_diff_json, output_null_log_odds_file)
|
583 |
+
|
584 |
+
|
585 |
+
def postprocess_output(all_examples,
|
586 |
+
all_features,
|
587 |
+
all_results,
|
588 |
+
n_best_size,
|
589 |
+
max_answer_length,
|
590 |
+
do_lower_case,
|
591 |
+
version_2_with_negative=False,
|
592 |
+
null_score_diff_threshold=0.0,
|
593 |
+
xlnet_format=False,
|
594 |
+
verbose=False):
|
595 |
+
"""Postprocess model output, to form predicton results."""
|
596 |
+
|
597 |
+
example_index_to_features = collections.defaultdict(list)
|
598 |
+
for feature in all_features:
|
599 |
+
example_index_to_features[feature.example_index].append(feature)
|
600 |
+
unique_id_to_result = {}
|
601 |
+
for result in all_results:
|
602 |
+
unique_id_to_result[result.unique_id] = result
|
603 |
+
|
604 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
605 |
+
"PrelimPrediction",
|
606 |
+
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
607 |
+
|
608 |
+
all_predictions = collections.OrderedDict()
|
609 |
+
all_nbest_json = collections.OrderedDict()
|
610 |
+
scores_diff_json = collections.OrderedDict()
|
611 |
+
|
612 |
+
for (example_index, example) in enumerate(all_examples):
|
613 |
+
features = example_index_to_features[example_index]
|
614 |
+
|
615 |
+
prelim_predictions = []
|
616 |
+
# keep track of the minimum score of null start+end of position 0
|
617 |
+
score_null = 1000000 # large and positive
|
618 |
+
min_null_feature_index = 0 # the paragraph slice with min null score
|
619 |
+
null_start_logit = 0 # the start logit at the slice with min null score
|
620 |
+
null_end_logit = 0 # the end logit at the slice with min null score
|
621 |
+
for (feature_index, feature) in enumerate(features):
|
622 |
+
if feature.unique_id not in unique_id_to_result:
|
623 |
+
logging.info("Skip eval example %s, not in pred.", feature.unique_id)
|
624 |
+
continue
|
625 |
+
result = unique_id_to_result[feature.unique_id]
|
626 |
+
|
627 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
628 |
+
if version_2_with_negative:
|
629 |
+
if xlnet_format:
|
630 |
+
feature_null_score = result.class_logits
|
631 |
+
else:
|
632 |
+
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
633 |
+
if feature_null_score < score_null:
|
634 |
+
score_null = feature_null_score
|
635 |
+
min_null_feature_index = feature_index
|
636 |
+
null_start_logit = result.start_logits[0]
|
637 |
+
null_end_logit = result.end_logits[0]
|
638 |
+
for (start_index, start_logit,
|
639 |
+
end_index, end_logit) in _get_best_indexes_and_logits(
|
640 |
+
result=result,
|
641 |
+
n_best_size=n_best_size,
|
642 |
+
xlnet_format=xlnet_format):
|
643 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
644 |
+
# that the start of the span is in the question. We throw out all
|
645 |
+
# invalid predictions.
|
646 |
+
if start_index >= len(feature.tokens):
|
647 |
+
continue
|
648 |
+
if end_index >= len(feature.tokens):
|
649 |
+
continue
|
650 |
+
if start_index not in feature.token_to_orig_map:
|
651 |
+
continue
|
652 |
+
if end_index not in feature.token_to_orig_map:
|
653 |
+
continue
|
654 |
+
if not feature.token_is_max_context.get(start_index, False):
|
655 |
+
continue
|
656 |
+
if end_index < start_index:
|
657 |
+
continue
|
658 |
+
length = end_index - start_index + 1
|
659 |
+
if length > max_answer_length:
|
660 |
+
continue
|
661 |
+
prelim_predictions.append(
|
662 |
+
_PrelimPrediction(
|
663 |
+
feature_index=feature_index,
|
664 |
+
start_index=start_index,
|
665 |
+
end_index=end_index,
|
666 |
+
start_logit=start_logit,
|
667 |
+
end_logit=end_logit))
|
668 |
+
|
669 |
+
if version_2_with_negative and not xlnet_format:
|
670 |
+
prelim_predictions.append(
|
671 |
+
_PrelimPrediction(
|
672 |
+
feature_index=min_null_feature_index,
|
673 |
+
start_index=0,
|
674 |
+
end_index=0,
|
675 |
+
start_logit=null_start_logit,
|
676 |
+
end_logit=null_end_logit))
|
677 |
+
prelim_predictions = sorted(
|
678 |
+
prelim_predictions,
|
679 |
+
key=lambda x: (x.start_logit + x.end_logit),
|
680 |
+
reverse=True)
|
681 |
+
|
682 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
683 |
+
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
684 |
+
|
685 |
+
seen_predictions = {}
|
686 |
+
nbest = []
|
687 |
+
for pred in prelim_predictions:
|
688 |
+
if len(nbest) >= n_best_size:
|
689 |
+
break
|
690 |
+
feature = features[pred.feature_index]
|
691 |
+
if pred.start_index > 0 or xlnet_format: # this is a non-null prediction
|
692 |
+
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
693 |
+
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
694 |
+
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
695 |
+
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
696 |
+
tok_text = " ".join(tok_tokens)
|
697 |
+
|
698 |
+
# De-tokenize WordPieces that have been split off.
|
699 |
+
tok_text = tok_text.replace(" ##", "")
|
700 |
+
tok_text = tok_text.replace("##", "")
|
701 |
+
|
702 |
+
# Clean whitespace
|
703 |
+
tok_text = tok_text.strip()
|
704 |
+
tok_text = " ".join(tok_text.split())
|
705 |
+
orig_text = " ".join(orig_tokens)
|
706 |
+
|
707 |
+
final_text = get_final_text(
|
708 |
+
tok_text, orig_text, do_lower_case, verbose=verbose)
|
709 |
+
if final_text in seen_predictions:
|
710 |
+
continue
|
711 |
+
|
712 |
+
seen_predictions[final_text] = True
|
713 |
+
else:
|
714 |
+
final_text = ""
|
715 |
+
seen_predictions[final_text] = True
|
716 |
+
|
717 |
+
nbest.append(
|
718 |
+
_NbestPrediction(
|
719 |
+
text=final_text,
|
720 |
+
start_logit=pred.start_logit,
|
721 |
+
end_logit=pred.end_logit))
|
722 |
+
|
723 |
+
# if we didn't include the empty option in the n-best, include it
|
724 |
+
if version_2_with_negative and not xlnet_format:
|
725 |
+
if "" not in seen_predictions:
|
726 |
+
nbest.append(
|
727 |
+
_NbestPrediction(
|
728 |
+
text="", start_logit=null_start_logit,
|
729 |
+
end_logit=null_end_logit))
|
730 |
+
# In very rare edge cases we could have no valid predictions. So we
|
731 |
+
# just create a nonce prediction in this case to avoid failure.
|
732 |
+
if not nbest:
|
733 |
+
nbest.append(
|
734 |
+
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
735 |
+
|
736 |
+
assert len(nbest) >= 1
|
737 |
+
|
738 |
+
total_scores = []
|
739 |
+
best_non_null_entry = None
|
740 |
+
for entry in nbest:
|
741 |
+
total_scores.append(entry.start_logit + entry.end_logit)
|
742 |
+
if not best_non_null_entry:
|
743 |
+
if entry.text:
|
744 |
+
best_non_null_entry = entry
|
745 |
+
|
746 |
+
probs = _compute_softmax(total_scores)
|
747 |
+
|
748 |
+
nbest_json = []
|
749 |
+
for (i, entry) in enumerate(nbest):
|
750 |
+
output = collections.OrderedDict()
|
751 |
+
output["text"] = entry.text
|
752 |
+
output["probability"] = probs[i]
|
753 |
+
output["start_logit"] = entry.start_logit
|
754 |
+
output["end_logit"] = entry.end_logit
|
755 |
+
nbest_json.append(output)
|
756 |
+
|
757 |
+
assert len(nbest_json) >= 1
|
758 |
+
|
759 |
+
if not version_2_with_negative:
|
760 |
+
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
761 |
+
else:
|
762 |
+
# pytype: disable=attribute-error
|
763 |
+
# predict "" iff the null score - the score of best non-null > threshold
|
764 |
+
if best_non_null_entry is not None:
|
765 |
+
if xlnet_format:
|
766 |
+
score_diff = score_null
|
767 |
+
scores_diff_json[example.qas_id] = score_diff
|
768 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
769 |
+
else:
|
770 |
+
score_diff = score_null - best_non_null_entry.start_logit - (
|
771 |
+
best_non_null_entry.end_logit)
|
772 |
+
scores_diff_json[example.qas_id] = score_diff
|
773 |
+
if score_diff > null_score_diff_threshold:
|
774 |
+
all_predictions[example.qas_id] = ""
|
775 |
+
else:
|
776 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
777 |
+
else:
|
778 |
+
logging.warning("best_non_null_entry is None")
|
779 |
+
scores_diff_json[example.qas_id] = score_null
|
780 |
+
all_predictions[example.qas_id] = ""
|
781 |
+
# pytype: enable=attribute-error
|
782 |
+
|
783 |
+
all_nbest_json[example.qas_id] = nbest_json
|
784 |
+
|
785 |
+
return all_predictions, all_nbest_json, scores_diff_json
|
786 |
+
|
787 |
+
|
788 |
+
def write_to_json_files(json_records, json_file):
|
789 |
+
with tf.io.gfile.GFile(json_file, "w") as writer:
|
790 |
+
writer.write(json.dumps(json_records, indent=4) + "\n")
|
791 |
+
|
792 |
+
|
793 |
+
def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
|
794 |
+
"""Project the tokenized prediction back to the original text."""
|
795 |
+
|
796 |
+
# When we created the data, we kept track of the alignment between original
|
797 |
+
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
798 |
+
# now `orig_text` contains the span of our original text corresponding to the
|
799 |
+
# span that we predicted.
|
800 |
+
#
|
801 |
+
# However, `orig_text` may contain extra characters that we don't want in
|
802 |
+
# our prediction.
|
803 |
+
#
|
804 |
+
# For example, let's say:
|
805 |
+
# pred_text = steve smith
|
806 |
+
# orig_text = Steve Smith's
|
807 |
+
#
|
808 |
+
# We don't want to return `orig_text` because it contains the extra "'s".
|
809 |
+
#
|
810 |
+
# We don't want to return `pred_text` because it's already been normalized
|
811 |
+
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
812 |
+
# our tokenizer does additional normalization like stripping accent
|
813 |
+
# characters).
|
814 |
+
#
|
815 |
+
# What we really want to return is "Steve Smith".
|
816 |
+
#
|
817 |
+
# Therefore, we have to apply a semi-complicated alignment heuristic between
|
818 |
+
# `pred_text` and `orig_text` to get a character-to-character alignment. This
|
819 |
+
# can fail in certain cases in which case we just return `orig_text`.
|
820 |
+
|
821 |
+
def _strip_spaces(text):
|
822 |
+
ns_chars = []
|
823 |
+
ns_to_s_map = collections.OrderedDict()
|
824 |
+
for (i, c) in enumerate(text):
|
825 |
+
if c == " ":
|
826 |
+
continue
|
827 |
+
ns_to_s_map[len(ns_chars)] = i
|
828 |
+
ns_chars.append(c)
|
829 |
+
ns_text = "".join(ns_chars)
|
830 |
+
return (ns_text, ns_to_s_map)
|
831 |
+
|
832 |
+
# We first tokenize `orig_text`, strip whitespace from the result
|
833 |
+
# and `pred_text`, and check if they are the same length. If they are
|
834 |
+
# NOT the same length, the heuristic has failed. If they are the same
|
835 |
+
# length, we assume the characters are one-to-one aligned.
|
836 |
+
tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
|
837 |
+
|
838 |
+
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
839 |
+
|
840 |
+
start_position = tok_text.find(pred_text)
|
841 |
+
if start_position == -1:
|
842 |
+
if verbose:
|
843 |
+
logging.info("Unable to find text: '%s' in '%s'", pred_text, orig_text)
|
844 |
+
return orig_text
|
845 |
+
end_position = start_position + len(pred_text) - 1
|
846 |
+
|
847 |
+
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
848 |
+
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
849 |
+
|
850 |
+
if len(orig_ns_text) != len(tok_ns_text):
|
851 |
+
if verbose:
|
852 |
+
logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
853 |
+
orig_ns_text, tok_ns_text)
|
854 |
+
return orig_text
|
855 |
+
|
856 |
+
# We then project the characters in `pred_text` back to `orig_text` using
|
857 |
+
# the character-to-character alignment.
|
858 |
+
tok_s_to_ns_map = {}
|
859 |
+
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
|
860 |
+
tok_s_to_ns_map[tok_index] = i
|
861 |
+
|
862 |
+
orig_start_position = None
|
863 |
+
if start_position in tok_s_to_ns_map:
|
864 |
+
ns_start_position = tok_s_to_ns_map[start_position]
|
865 |
+
if ns_start_position in orig_ns_to_s_map:
|
866 |
+
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
867 |
+
|
868 |
+
if orig_start_position is None:
|
869 |
+
if verbose:
|
870 |
+
logging.info("Couldn't map start position")
|
871 |
+
return orig_text
|
872 |
+
|
873 |
+
orig_end_position = None
|
874 |
+
if end_position in tok_s_to_ns_map:
|
875 |
+
ns_end_position = tok_s_to_ns_map[end_position]
|
876 |
+
if ns_end_position in orig_ns_to_s_map:
|
877 |
+
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
878 |
+
|
879 |
+
if orig_end_position is None:
|
880 |
+
if verbose:
|
881 |
+
logging.info("Couldn't map end position")
|
882 |
+
return orig_text
|
883 |
+
|
884 |
+
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
885 |
+
return output_text
|
886 |
+
|
887 |
+
|
888 |
+
def _get_best_indexes_and_logits(result,
|
889 |
+
n_best_size,
|
890 |
+
xlnet_format=False):
|
891 |
+
"""Generates the n-best indexes and logits from a list."""
|
892 |
+
if xlnet_format:
|
893 |
+
for i in range(n_best_size):
|
894 |
+
for j in range(n_best_size):
|
895 |
+
j_index = i * n_best_size + j
|
896 |
+
yield (result.start_indexes[i], result.start_logits[i],
|
897 |
+
result.end_indexes[j_index], result.end_logits[j_index])
|
898 |
+
else:
|
899 |
+
start_index_and_score = sorted(enumerate(result.start_logits),
|
900 |
+
key=lambda x: x[1], reverse=True)
|
901 |
+
end_index_and_score = sorted(enumerate(result.end_logits),
|
902 |
+
key=lambda x: x[1], reverse=True)
|
903 |
+
for i in range(len(start_index_and_score)):
|
904 |
+
if i >= n_best_size:
|
905 |
+
break
|
906 |
+
for j in range(len(end_index_and_score)):
|
907 |
+
if j >= n_best_size:
|
908 |
+
break
|
909 |
+
yield (start_index_and_score[i][0], start_index_and_score[i][1],
|
910 |
+
end_index_and_score[j][0], end_index_and_score[j][1])
|
911 |
+
|
912 |
+
|
913 |
+
def _compute_softmax(scores):
|
914 |
+
"""Compute softmax probability over raw logits."""
|
915 |
+
if not scores:
|
916 |
+
return []
|
917 |
+
|
918 |
+
max_score = None
|
919 |
+
for score in scores:
|
920 |
+
if max_score is None or score > max_score:
|
921 |
+
max_score = score
|
922 |
+
|
923 |
+
exp_scores = []
|
924 |
+
total_sum = 0.0
|
925 |
+
for score in scores:
|
926 |
+
x = math.exp(score - max_score)
|
927 |
+
exp_scores.append(x)
|
928 |
+
total_sum += x
|
929 |
+
|
930 |
+
probs = []
|
931 |
+
for score in exp_scores:
|
932 |
+
probs.append(score / total_sum)
|
933 |
+
return probs
|
934 |
+
|
935 |
+
|
936 |
+
def generate_tf_record_from_json_file(input_file_path,
|
937 |
+
vocab_file_path,
|
938 |
+
output_path,
|
939 |
+
translated_input_folder=None,
|
940 |
+
max_seq_length=384,
|
941 |
+
do_lower_case=True,
|
942 |
+
max_query_length=64,
|
943 |
+
doc_stride=128,
|
944 |
+
version_2_with_negative=False,
|
945 |
+
xlnet_format=False):
|
946 |
+
"""Generates and saves training data into a tf record file."""
|
947 |
+
train_examples = read_squad_examples(
|
948 |
+
input_file=input_file_path,
|
949 |
+
is_training=True,
|
950 |
+
version_2_with_negative=version_2_with_negative,
|
951 |
+
translated_input_folder=translated_input_folder)
|
952 |
+
tokenizer = tokenization.FullTokenizer(
|
953 |
+
vocab_file=vocab_file_path, do_lower_case=do_lower_case)
|
954 |
+
train_writer = FeatureWriter(filename=output_path, is_training=True)
|
955 |
+
number_of_examples = convert_examples_to_features(
|
956 |
+
examples=train_examples,
|
957 |
+
tokenizer=tokenizer,
|
958 |
+
max_seq_length=max_seq_length,
|
959 |
+
doc_stride=doc_stride,
|
960 |
+
max_query_length=max_query_length,
|
961 |
+
is_training=True,
|
962 |
+
output_fn=train_writer.process_feature,
|
963 |
+
xlnet_format=xlnet_format)
|
964 |
+
train_writer.close()
|
965 |
+
|
966 |
+
meta_data = {
|
967 |
+
"task_type": "bert_squad",
|
968 |
+
"train_data_size": number_of_examples,
|
969 |
+
"max_seq_length": max_seq_length,
|
970 |
+
"max_query_length": max_query_length,
|
971 |
+
"doc_stride": doc_stride,
|
972 |
+
"version_2_with_negative": version_2_with_negative,
|
973 |
+
}
|
974 |
+
|
975 |
+
return meta_data
|
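For reference, a minimal usage sketch of the WordPiece-based pipeline defined in squad_lib.py above. The module name and all file paths are assumptions for illustration only; they are not part of this commit.

    import squad_lib  # assumes squad_lib.py is importable from the working directory

    meta_data = squad_lib.generate_tf_record_from_json_file(
        input_file_path="/tmp/train-v1.1.json",    # hypothetical SQuAD 1.1 training JSON
        vocab_file_path="/tmp/vocab.txt",          # hypothetical BERT WordPiece vocabulary
        output_path="/tmp/squad_train.tf_record",  # destination TFRecord file
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        version_2_with_negative=False)
    # meta_data records the generated example count and the preprocessing settings.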
squad_lib_sp.py
ADDED
@@ -0,0 +1,976 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.
|
16 |
+
|
17 |
+
The file is forked from:
|
18 |
+
|
19 |
+
https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
|
20 |
+
"""
|
21 |
+
import collections
|
22 |
+
import copy
|
23 |
+
import json
|
24 |
+
import math
|
25 |
+
import os
|
26 |
+
|
27 |
+
from absl import logging
|
28 |
+
import numpy as np
|
29 |
+
import tensorflow as tf, tf_keras
|
30 |
+
|
31 |
+
from official.nlp.tools import tokenization
|
32 |
+
|
33 |
+
|
34 |
+
class SquadExample(object):
|
35 |
+
"""A single training/test example for simple sequence classification.
|
36 |
+
|
37 |
+
For examples without an answer, the start and end position are -1.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self,
|
41 |
+
qas_id,
|
42 |
+
question_text,
|
43 |
+
paragraph_text,
|
44 |
+
orig_answer_text=None,
|
45 |
+
start_position=None,
|
46 |
+
end_position=None,
|
47 |
+
is_impossible=False):
|
48 |
+
self.qas_id = qas_id
|
49 |
+
self.question_text = question_text
|
50 |
+
self.paragraph_text = paragraph_text
|
51 |
+
self.orig_answer_text = orig_answer_text
|
52 |
+
self.start_position = start_position
|
53 |
+
self.end_position = end_position
|
54 |
+
self.is_impossible = is_impossible
|
55 |
+
|
56 |
+
def __str__(self):
|
57 |
+
return self.__repr__()
|
58 |
+
|
59 |
+
def __repr__(self):
|
60 |
+
s = ""
|
61 |
+
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
|
62 |
+
s += ", question_text: %s" % (
|
63 |
+
tokenization.printable_text(self.question_text))
|
64 |
+
s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
|
65 |
+
if self.start_position:
|
66 |
+
s += ", start_position: %d" % (self.start_position,)
|
67 |
+
if self.start_position:
|
68 |
+
s += ", end_position: %d" % (self.end_position)
|
69 |
+
if self.start_position:
|
70 |
+
s += ", is_impossible: %r" % (self.is_impossible)
|
71 |
+
return s
|
72 |
+
|
73 |
+
|
74 |
+
class InputFeatures(object):
|
75 |
+
"""A single set of features of data."""
|
76 |
+
|
77 |
+
def __init__(self,
|
78 |
+
unique_id,
|
79 |
+
example_index,
|
80 |
+
doc_span_index,
|
81 |
+
tok_start_to_orig_index,
|
82 |
+
tok_end_to_orig_index,
|
83 |
+
token_is_max_context,
|
84 |
+
tokens,
|
85 |
+
input_ids,
|
86 |
+
input_mask,
|
87 |
+
segment_ids,
|
88 |
+
paragraph_len,
|
89 |
+
class_index=None,
|
90 |
+
paragraph_mask=None,
|
91 |
+
start_position=None,
|
92 |
+
end_position=None,
|
93 |
+
is_impossible=None):
|
94 |
+
self.unique_id = unique_id
|
95 |
+
self.example_index = example_index
|
96 |
+
self.doc_span_index = doc_span_index
|
97 |
+
self.tok_start_to_orig_index = tok_start_to_orig_index
|
98 |
+
self.tok_end_to_orig_index = tok_end_to_orig_index
|
99 |
+
self.token_is_max_context = token_is_max_context
|
100 |
+
self.tokens = tokens
|
101 |
+
self.input_ids = input_ids
|
102 |
+
self.input_mask = input_mask
|
103 |
+
self.paragraph_mask = paragraph_mask
|
104 |
+
self.segment_ids = segment_ids
|
105 |
+
self.paragraph_len = paragraph_len
|
106 |
+
self.class_index = class_index
|
107 |
+
self.start_position = start_position
|
108 |
+
self.end_position = end_position
|
109 |
+
self.is_impossible = is_impossible
|
110 |
+
|
111 |
+
|
112 |
+
def read_squad_examples(input_file,
|
113 |
+
is_training,
|
114 |
+
version_2_with_negative,
|
115 |
+
translated_input_folder=None):
|
116 |
+
"""Read a SQuAD json file into a list of SquadExample."""
|
117 |
+
del version_2_with_negative
|
118 |
+
with tf.io.gfile.GFile(input_file, "r") as reader:
|
119 |
+
input_data = json.load(reader)["data"]
|
120 |
+
|
121 |
+
if translated_input_folder is not None:
|
122 |
+
translated_files = tf.io.gfile.glob(
|
123 |
+
os.path.join(translated_input_folder, "*.json"))
|
124 |
+
for file in translated_files:
|
125 |
+
with tf.io.gfile.GFile(file, "r") as reader:
|
126 |
+
input_data.extend(json.load(reader)["data"])
|
127 |
+
|
128 |
+
examples = []
|
129 |
+
for entry in input_data:
|
130 |
+
for paragraph in entry["paragraphs"]:
|
131 |
+
paragraph_text = paragraph["context"]
|
132 |
+
|
133 |
+
for qa in paragraph["qas"]:
|
134 |
+
qas_id = qa["id"]
|
135 |
+
question_text = qa["question"]
|
136 |
+
start_position = None
|
137 |
+
orig_answer_text = None
|
138 |
+
is_impossible = False
|
139 |
+
|
140 |
+
if is_training:
|
141 |
+
is_impossible = qa.get("is_impossible", False)
|
142 |
+
if (len(qa["answers"]) != 1) and (not is_impossible):
|
143 |
+
raise ValueError(
|
144 |
+
"For training, each question should have exactly 1 answer.")
|
145 |
+
if not is_impossible:
|
146 |
+
answer = qa["answers"][0]
|
147 |
+
orig_answer_text = answer["text"]
|
148 |
+
start_position = answer["answer_start"]
|
149 |
+
else:
|
150 |
+
start_position = -1
|
151 |
+
orig_answer_text = ""
|
152 |
+
|
153 |
+
example = SquadExample(
|
154 |
+
qas_id=qas_id,
|
155 |
+
question_text=question_text,
|
156 |
+
paragraph_text=paragraph_text,
|
157 |
+
orig_answer_text=orig_answer_text,
|
158 |
+
start_position=start_position,
|
159 |
+
is_impossible=is_impossible)
|
160 |
+
examples.append(example)
|
161 |
+
|
162 |
+
return examples
|
163 |
+
|
164 |
+
|
165 |
+
def _convert_index(index, pos, m=None, is_start=True):
|
166 |
+
"""Converts index."""
|
167 |
+
if index[pos] is not None:
|
168 |
+
return index[pos]
|
169 |
+
n = len(index)
|
170 |
+
rear = pos
|
171 |
+
while rear < n - 1 and index[rear] is None:
|
172 |
+
rear += 1
|
173 |
+
front = pos
|
174 |
+
while front > 0 and index[front] is None:
|
175 |
+
front -= 1
|
176 |
+
assert index[front] is not None or index[rear] is not None
|
177 |
+
if index[front] is None:
|
178 |
+
if index[rear] >= 1: # pytype: disable=unsupported-operands
|
179 |
+
if is_start:
|
180 |
+
return 0
|
181 |
+
else:
|
182 |
+
return index[rear] - 1
|
183 |
+
return index[rear]
|
184 |
+
if index[rear] is None:
|
185 |
+
if m is not None and index[front] < m - 1:
|
186 |
+
if is_start:
|
187 |
+
return index[front] + 1
|
188 |
+
else:
|
189 |
+
return m - 1
|
190 |
+
return index[front]
|
191 |
+
if is_start:
|
192 |
+
if index[rear] > index[front] + 1:
|
193 |
+
return index[front] + 1
|
194 |
+
else:
|
195 |
+
return index[rear]
|
196 |
+
else:
|
197 |
+
if index[rear] > index[front] + 1:
|
198 |
+
return index[rear] - 1
|
199 |
+
else:
|
200 |
+
return index[front]
|
201 |
+
|
202 |
+
|
203 |
+
def convert_examples_to_features(examples,
|
204 |
+
tokenizer,
|
205 |
+
max_seq_length,
|
206 |
+
doc_stride,
|
207 |
+
max_query_length,
|
208 |
+
is_training,
|
209 |
+
output_fn,
|
210 |
+
do_lower_case,
|
211 |
+
xlnet_format=False,
|
212 |
+
batch_size=None):
|
213 |
+
"""Loads a data file into a list of `InputBatch`s."""
|
214 |
+
cnt_pos, cnt_neg = 0, 0
|
215 |
+
base_id = 1000000000
|
216 |
+
unique_id = base_id
|
217 |
+
max_n, max_m = 1024, 1024
|
218 |
+
f = np.zeros((max_n, max_m), dtype=np.float32)
|
219 |
+
|
220 |
+
for (example_index, example) in enumerate(examples):
|
221 |
+
|
222 |
+
if example_index % 100 == 0:
|
223 |
+
logging.info("Converting %d/%d pos %d neg %d", example_index,
|
224 |
+
len(examples), cnt_pos, cnt_neg)
|
225 |
+
|
226 |
+
query_tokens = tokenization.encode_ids(
|
227 |
+
tokenizer.sp_model,
|
228 |
+
tokenization.preprocess_text(
|
229 |
+
example.question_text, lower=do_lower_case))
|
230 |
+
|
231 |
+
if len(query_tokens) > max_query_length:
|
232 |
+
query_tokens = query_tokens[0:max_query_length]
|
233 |
+
|
234 |
+
paragraph_text = example.paragraph_text
|
235 |
+
para_tokens = tokenization.encode_pieces(
|
236 |
+
tokenizer.sp_model,
|
237 |
+
tokenization.preprocess_text(
|
238 |
+
example.paragraph_text, lower=do_lower_case))
|
239 |
+
|
240 |
+
chartok_to_tok_index = []
|
241 |
+
tok_start_to_chartok_index = []
|
242 |
+
tok_end_to_chartok_index = []
|
243 |
+
char_cnt = 0
|
244 |
+
for i, token in enumerate(para_tokens):
|
245 |
+
new_token = token.replace(tokenization.SPIECE_UNDERLINE, " ")
|
246 |
+
chartok_to_tok_index.extend([i] * len(new_token))
|
247 |
+
tok_start_to_chartok_index.append(char_cnt)
|
248 |
+
char_cnt += len(new_token)
|
249 |
+
tok_end_to_chartok_index.append(char_cnt - 1)
|
250 |
+
|
251 |
+
tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE,
|
252 |
+
" ")
|
253 |
+
n, m = len(paragraph_text), len(tok_cat_text)
|
254 |
+
|
255 |
+
if n > max_n or m > max_m:
|
256 |
+
max_n = max(n, max_n)
|
257 |
+
max_m = max(m, max_m)
|
258 |
+
f = np.zeros((max_n, max_m), dtype=np.float32)
|
259 |
+
|
260 |
+
g = {}
|
261 |
+
|
262 |
+
# pylint: disable=cell-var-from-loop
|
263 |
+
def _lcs_match(max_dist, n=n, m=m):
|
264 |
+
"""Longest-common-substring algorithm."""
|
265 |
+
f.fill(0)
|
266 |
+
g.clear()
|
267 |
+
|
268 |
+
### longest common subsequence
|
269 |
+
# f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
|
270 |
+
for i in range(n):
|
271 |
+
|
272 |
+
# unlike standard LCS, this is specifically optimized for the setting
|
273 |
+
# because the mismatch between sentence pieces and original text will
|
274 |
+
# be small
|
275 |
+
for j in range(i - max_dist, i + max_dist):
|
276 |
+
if j >= m or j < 0:
|
277 |
+
continue
|
278 |
+
|
279 |
+
if i > 0:
|
280 |
+
g[(i, j)] = 0
|
281 |
+
f[i, j] = f[i - 1, j]
|
282 |
+
|
283 |
+
if j > 0 and f[i, j - 1] > f[i, j]:
|
284 |
+
g[(i, j)] = 1
|
285 |
+
f[i, j] = f[i, j - 1]
|
286 |
+
|
287 |
+
f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
|
288 |
+
if (tokenization.preprocess_text(
|
289 |
+
paragraph_text[i], lower=do_lower_case,
|
290 |
+
remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
|
291 |
+
g[(i, j)] = 2
|
292 |
+
f[i, j] = f_prev + 1
|
293 |
+
|
294 |
+
# pylint: enable=cell-var-from-loop
|
295 |
+
|
296 |
+
max_dist = abs(n - m) + 5
|
297 |
+
for _ in range(2):
|
298 |
+
_lcs_match(max_dist)
|
299 |
+
if f[n - 1, m - 1] > 0.8 * n:
|
300 |
+
break
|
301 |
+
max_dist *= 2
|
302 |
+
|
303 |
+
orig_to_chartok_index = [None] * n
|
304 |
+
chartok_to_orig_index = [None] * m
|
305 |
+
i, j = n - 1, m - 1
|
306 |
+
while i >= 0 and j >= 0:
|
307 |
+
if (i, j) not in g:
|
308 |
+
break
|
309 |
+
if g[(i, j)] == 2:
|
310 |
+
orig_to_chartok_index[i] = j
|
311 |
+
chartok_to_orig_index[j] = i
|
312 |
+
i, j = i - 1, j - 1
|
313 |
+
elif g[(i, j)] == 1:
|
314 |
+
j = j - 1
|
315 |
+
else:
|
316 |
+
i = i - 1
|
317 |
+
|
318 |
+
if (all(v is None for v in orig_to_chartok_index) or
|
319 |
+
f[n - 1, m - 1] < 0.8 * n):
|
320 |
+
logging.info("MISMATCH DETECTED!")
|
321 |
+
continue
|
322 |
+
|
323 |
+
tok_start_to_orig_index = []
|
324 |
+
tok_end_to_orig_index = []
|
325 |
+
for i in range(len(para_tokens)):
|
326 |
+
start_chartok_pos = tok_start_to_chartok_index[i]
|
327 |
+
end_chartok_pos = tok_end_to_chartok_index[i]
|
328 |
+
start_orig_pos = _convert_index(
|
329 |
+
chartok_to_orig_index, start_chartok_pos, n, is_start=True)
|
330 |
+
end_orig_pos = _convert_index(
|
331 |
+
chartok_to_orig_index, end_chartok_pos, n, is_start=False)
|
332 |
+
|
333 |
+
tok_start_to_orig_index.append(start_orig_pos)
|
334 |
+
tok_end_to_orig_index.append(end_orig_pos)
|
335 |
+
|
336 |
+
if not is_training:
|
337 |
+
tok_start_position = tok_end_position = None
|
338 |
+
|
339 |
+
if is_training and example.is_impossible:
|
340 |
+
tok_start_position = 0
|
341 |
+
tok_end_position = 0
|
342 |
+
|
343 |
+
if is_training and not example.is_impossible:
|
344 |
+
start_position = example.start_position
|
345 |
+
end_position = start_position + len(example.orig_answer_text) - 1
|
346 |
+
|
347 |
+
start_chartok_pos = _convert_index(
|
348 |
+
orig_to_chartok_index, start_position, is_start=True)
|
349 |
+
tok_start_position = chartok_to_tok_index[start_chartok_pos]
|
350 |
+
|
351 |
+
end_chartok_pos = _convert_index(
|
352 |
+
orig_to_chartok_index, end_position, is_start=False)
|
353 |
+
tok_end_position = chartok_to_tok_index[end_chartok_pos]
|
354 |
+
assert tok_start_position <= tok_end_position
|
355 |
+
|
356 |
+
def _piece_to_id(x):
|
357 |
+
return tokenizer.sp_model.PieceToId(x)
|
358 |
+
|
359 |
+
all_doc_tokens = list(map(_piece_to_id, para_tokens))
|
360 |
+
|
361 |
+
# The -3 accounts for [CLS], [SEP] and [SEP]
|
362 |
+
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
363 |
+
|
364 |
+
# We can have documents that are longer than the maximum sequence length.
|
365 |
+
# To deal with this we do a sliding window approach, where we take chunks
|
366 |
+
# of up to our max length with a stride of `doc_stride`.
|
367 |
+
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
368 |
+
"DocSpan", ["start", "length"])
|
369 |
+
doc_spans = []
|
370 |
+
start_offset = 0
|
371 |
+
|
372 |
+
while start_offset < len(all_doc_tokens):
|
373 |
+
length = len(all_doc_tokens) - start_offset
|
374 |
+
if length > max_tokens_for_doc:
|
375 |
+
length = max_tokens_for_doc
|
376 |
+
doc_spans.append(_DocSpan(start=start_offset, length=length))
|
377 |
+
if start_offset + length == len(all_doc_tokens):
|
378 |
+
break
|
379 |
+
start_offset += min(length, doc_stride)
|
380 |
+
|
381 |
+
for (doc_span_index, doc_span) in enumerate(doc_spans):
|
382 |
+
tokens = []
|
383 |
+
token_is_max_context = {}
|
384 |
+
segment_ids = []
|
385 |
+
|
386 |
+
# Paragraph mask used in XLNet.
|
387 |
+
# 1 represents paragraph and class tokens.
|
388 |
+
# 0 represents query and other special tokens.
|
389 |
+
paragraph_mask = []
|
390 |
+
|
391 |
+
cur_tok_start_to_orig_index = []
|
392 |
+
cur_tok_end_to_orig_index = []
|
393 |
+
|
394 |
+
# pylint: disable=cell-var-from-loop
|
395 |
+
def process_query(seg_q):
|
396 |
+
for token in query_tokens:
|
397 |
+
tokens.append(token)
|
398 |
+
segment_ids.append(seg_q)
|
399 |
+
paragraph_mask.append(0)
|
400 |
+
tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
|
401 |
+
segment_ids.append(seg_q)
|
402 |
+
paragraph_mask.append(0)
|
403 |
+
|
404 |
+
def process_paragraph(seg_p):
|
405 |
+
for i in range(doc_span.length):
|
406 |
+
split_token_index = doc_span.start + i
|
407 |
+
|
408 |
+
cur_tok_start_to_orig_index.append(
|
409 |
+
tok_start_to_orig_index[split_token_index])
|
410 |
+
cur_tok_end_to_orig_index.append(
|
411 |
+
tok_end_to_orig_index[split_token_index])
|
412 |
+
|
413 |
+
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
414 |
+
split_token_index)
|
415 |
+
token_is_max_context[len(tokens)] = is_max_context
|
416 |
+
tokens.append(all_doc_tokens[split_token_index])
|
417 |
+
segment_ids.append(seg_p)
|
418 |
+
paragraph_mask.append(1)
|
419 |
+
tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
|
420 |
+
segment_ids.append(seg_p)
|
421 |
+
paragraph_mask.append(0)
|
422 |
+
return len(tokens)
|
423 |
+
|
424 |
+
def process_class(seg_class):
|
425 |
+
class_index = len(segment_ids)
|
426 |
+
tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
|
427 |
+
segment_ids.append(seg_class)
|
428 |
+
paragraph_mask.append(1)
|
429 |
+
return class_index
|
430 |
+
|
431 |
+
if xlnet_format:
|
432 |
+
seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
|
433 |
+
paragraph_len = process_paragraph(seg_p)
|
434 |
+
process_query(seg_q)
|
435 |
+
class_index = process_class(seg_class)
|
436 |
+
else:
|
437 |
+
seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
|
438 |
+
class_index = process_class(seg_class)
|
439 |
+
process_query(seg_q)
|
440 |
+
paragraph_len = process_paragraph(seg_p)
|
441 |
+
|
442 |
+
input_ids = tokens
|
443 |
+
|
444 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
445 |
+
# tokens are attended to.
|
446 |
+
input_mask = [1] * len(input_ids)
|
447 |
+
|
448 |
+
# Zero-pad up to the sequence length.
|
449 |
+
while len(input_ids) < max_seq_length:
|
450 |
+
input_ids.append(0)
|
451 |
+
input_mask.append(0)
|
452 |
+
segment_ids.append(seg_pad)
|
453 |
+
paragraph_mask.append(0)
|
454 |
+
|
455 |
+
assert len(input_ids) == max_seq_length
|
456 |
+
assert len(input_mask) == max_seq_length
|
457 |
+
assert len(segment_ids) == max_seq_length
|
458 |
+
assert len(paragraph_mask) == max_seq_length
|
459 |
+
|
460 |
+
span_is_impossible = example.is_impossible
|
461 |
+
start_position = None
|
462 |
+
end_position = None
|
463 |
+
if is_training and not span_is_impossible:
|
464 |
+
# For training, if our document chunk does not contain an annotation
|
465 |
+
# we throw it out, since there is nothing to predict.
|
466 |
+
doc_start = doc_span.start
|
467 |
+
doc_end = doc_span.start + doc_span.length - 1
|
468 |
+
out_of_span = False
|
469 |
+
if not (tok_start_position >= doc_start and
|
470 |
+
tok_end_position <= doc_end):
|
471 |
+
out_of_span = True
|
472 |
+
if out_of_span:
|
473 |
+
# continue
|
474 |
+
start_position = 0
|
475 |
+
end_position = 0
|
476 |
+
span_is_impossible = True
|
477 |
+
else:
|
478 |
+
doc_offset = 0 if xlnet_format else len(query_tokens) + 2
|
479 |
+
start_position = tok_start_position - doc_start + doc_offset
|
480 |
+
end_position = tok_end_position - doc_start + doc_offset
|
481 |
+
|
482 |
+
if is_training and span_is_impossible:
|
483 |
+
start_position = class_index
|
484 |
+
end_position = class_index
|
485 |
+
|
486 |
+
if example_index < 20:
|
487 |
+
logging.info("*** Example ***")
|
488 |
+
logging.info("unique_id: %s", (unique_id))
|
489 |
+
logging.info("example_index: %s", (example_index))
|
490 |
+
logging.info("doc_span_index: %s", (doc_span_index))
|
491 |
+
logging.info("tok_start_to_orig_index: %s",
|
492 |
+
" ".join([str(x) for x in cur_tok_start_to_orig_index]))
|
493 |
+
logging.info("tok_end_to_orig_index: %s",
|
494 |
+
" ".join([str(x) for x in cur_tok_end_to_orig_index]))
|
495 |
+
logging.info(
|
496 |
+
"token_is_max_context: %s", " ".join(
|
497 |
+
["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
|
498 |
+
logging.info(
|
499 |
+
"input_pieces: %s",
|
500 |
+
" ".join([tokenizer.sp_model.IdToPiece(x) for x in tokens]))
|
501 |
+
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
|
502 |
+
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
|
503 |
+
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
|
504 |
+
logging.info("paragraph_mask: %s", " ".join(
|
505 |
+
[str(x) for x in paragraph_mask]))
|
506 |
+
logging.info("class_index: %d", class_index)
|
507 |
+
|
508 |
+
if is_training and span_is_impossible:
|
509 |
+
logging.info("impossible example span")
|
510 |
+
|
511 |
+
if is_training and not span_is_impossible:
|
512 |
+
pieces = [
|
513 |
+
tokenizer.sp_model.IdToPiece(token)
|
514 |
+
for token in tokens[start_position:(end_position + 1)]
|
515 |
+
]
|
516 |
+
answer_text = tokenizer.sp_model.DecodePieces(pieces)
|
517 |
+
logging.info("start_position: %d", (start_position))
|
518 |
+
logging.info("end_position: %d", (end_position))
|
519 |
+
logging.info("answer: %s", (tokenization.printable_text(answer_text)))
|
520 |
+
|
521 |
+
# With multi processing, the example_index is actually the index
|
522 |
+
# within the current process therefore we use example_index=None
|
523 |
+
# to avoid being used in the future.
|
524 |
+
# The current code does not use example_index of training data.
|
525 |
+
if is_training:
|
526 |
+
feat_example_index = None
|
527 |
+
else:
|
528 |
+
feat_example_index = example_index
|
529 |
+
|
530 |
+
feature = InputFeatures(
|
531 |
+
unique_id=unique_id,
|
532 |
+
example_index=feat_example_index,
|
533 |
+
doc_span_index=doc_span_index,
|
534 |
+
tok_start_to_orig_index=cur_tok_start_to_orig_index,
|
535 |
+
tok_end_to_orig_index=cur_tok_end_to_orig_index,
|
536 |
+
token_is_max_context=token_is_max_context,
|
537 |
+
tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
|
538 |
+
input_ids=input_ids,
|
539 |
+
input_mask=input_mask,
|
540 |
+
paragraph_mask=paragraph_mask,
|
541 |
+
segment_ids=segment_ids,
|
542 |
+
paragraph_len=paragraph_len,
|
543 |
+
class_index=class_index,
|
544 |
+
start_position=start_position,
|
545 |
+
end_position=end_position,
|
546 |
+
is_impossible=span_is_impossible)
|
547 |
+
|
548 |
+
# Run callback
|
549 |
+
if is_training:
|
550 |
+
output_fn(feature)
|
551 |
+
else:
|
552 |
+
output_fn(feature, is_padding=False)
|
553 |
+
|
554 |
+
unique_id += 1
|
555 |
+
if span_is_impossible:
|
556 |
+
cnt_neg += 1
|
557 |
+
else:
|
558 |
+
cnt_pos += 1
|
559 |
+
|
560 |
+
if not is_training and feature:
|
561 |
+
assert batch_size
|
562 |
+
num_padding = 0
|
563 |
+
num_examples = unique_id - base_id
|
564 |
+
if unique_id % batch_size != 0:
|
565 |
+
num_padding = batch_size - (num_examples % batch_size)
|
566 |
+
dummy_feature = copy.deepcopy(feature)
|
567 |
+
for _ in range(num_padding):
|
568 |
+
dummy_feature.unique_id = unique_id
|
569 |
+
|
570 |
+
# Run callback
|
571 |
+
output_fn(feature, is_padding=True)
|
572 |
+
unique_id += 1
|
573 |
+
|
574 |
+
logging.info("Total number of instances: %d = pos %d neg %d",
|
575 |
+
cnt_pos + cnt_neg, cnt_pos, cnt_neg)
|
576 |
+
return unique_id - base_id
|
577 |
+
|
578 |
+
|
579 |
+
def _check_is_max_context(doc_spans, cur_span_index, position):
|
580 |
+
"""Check if this is the 'max context' doc span for the token."""
|
581 |
+
|
582 |
+
# Because of the sliding window approach taken to scoring documents, a single
|
583 |
+
# token can appear in multiple documents. E.g.
|
584 |
+
# Doc: the man went to the store and bought a gallon of milk
|
585 |
+
# Span A: the man went to the
|
586 |
+
# Span B: to the store and bought
|
587 |
+
# Span C: and bought a gallon of
|
588 |
+
# ...
|
589 |
+
#
|
590 |
+
# Now the word 'bought' will have two scores from spans B and C. We only
|
591 |
+
# want to consider the score with "maximum context", which we define as
|
592 |
+
# the *minimum* of its left and right context (the *sum* of left and
|
593 |
+
# right context will always be the same, of course).
|
594 |
+
#
|
595 |
+
# In the example the maximum context for 'bought' would be span C since
|
596 |
+
# it has 1 left context and 3 right context, while span B has 4 left context
|
597 |
+
# and 0 right context.
|
598 |
+
best_score = None
|
599 |
+
best_span_index = None
|
600 |
+
for (span_index, doc_span) in enumerate(doc_spans):
|
601 |
+
end = doc_span.start + doc_span.length - 1
|
602 |
+
if position < doc_span.start:
|
603 |
+
continue
|
604 |
+
if position > end:
|
605 |
+
continue
|
606 |
+
num_left_context = position - doc_span.start
|
607 |
+
num_right_context = end - position
|
608 |
+
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
609 |
+
if best_score is None or score > best_score:
|
610 |
+
best_score = score
|
611 |
+
best_span_index = span_index
|
612 |
+
|
613 |
+
return cur_span_index == best_span_index
|
614 |
+
|
615 |
+
|
616 |
+
def write_predictions(all_examples,
|
617 |
+
all_features,
|
618 |
+
all_results,
|
619 |
+
n_best_size,
|
620 |
+
max_answer_length,
|
621 |
+
do_lower_case,
|
622 |
+
output_prediction_file,
|
623 |
+
output_nbest_file,
|
624 |
+
output_null_log_odds_file,
|
625 |
+
version_2_with_negative=False,
|
626 |
+
null_score_diff_threshold=0.0,
|
627 |
+
verbose=False):
|
628 |
+
"""Write final predictions to the json file and log-odds of null if needed."""
|
629 |
+
logging.info("Writing predictions to: %s", (output_prediction_file))
|
630 |
+
logging.info("Writing nbest to: %s", (output_nbest_file))
|
631 |
+
|
632 |
+
all_predictions, all_nbest_json, scores_diff_json = (
|
633 |
+
postprocess_output(
|
634 |
+
all_examples=all_examples,
|
635 |
+
all_features=all_features,
|
636 |
+
all_results=all_results,
|
637 |
+
n_best_size=n_best_size,
|
638 |
+
max_answer_length=max_answer_length,
|
639 |
+
do_lower_case=do_lower_case,
|
640 |
+
version_2_with_negative=version_2_with_negative,
|
641 |
+
null_score_diff_threshold=null_score_diff_threshold,
|
642 |
+
verbose=verbose))
|
643 |
+
|
644 |
+
write_to_json_files(all_predictions, output_prediction_file)
|
645 |
+
write_to_json_files(all_nbest_json, output_nbest_file)
|
646 |
+
if version_2_with_negative:
|
647 |
+
write_to_json_files(scores_diff_json, output_null_log_odds_file)
|
648 |
+
|
649 |
+
|
650 |
+
def postprocess_output(all_examples,
|
651 |
+
all_features,
|
652 |
+
all_results,
|
653 |
+
n_best_size,
|
654 |
+
max_answer_length,
|
655 |
+
do_lower_case,
|
656 |
+
version_2_with_negative=False,
|
657 |
+
null_score_diff_threshold=0.0,
|
658 |
+
xlnet_format=False,
|
659 |
+
verbose=False):
|
660 |
+
"""Postprocess model output, to form predicton results."""
|
661 |
+
|
662 |
+
del do_lower_case, verbose
|
663 |
+
example_index_to_features = collections.defaultdict(list)
|
664 |
+
for feature in all_features:
|
665 |
+
example_index_to_features[feature.example_index].append(feature)
|
666 |
+
|
667 |
+
unique_id_to_result = {}
|
668 |
+
for result in all_results:
|
669 |
+
unique_id_to_result[result.unique_id] = result
|
670 |
+
|
671 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
672 |
+
"PrelimPrediction",
|
673 |
+
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
674 |
+
|
675 |
+
all_predictions = collections.OrderedDict()
|
676 |
+
all_nbest_json = collections.OrderedDict()
|
677 |
+
scores_diff_json = collections.OrderedDict()
|
678 |
+
|
679 |
+
for (example_index, example) in enumerate(all_examples):
|
680 |
+
features = example_index_to_features[example_index]
|
681 |
+
|
682 |
+
prelim_predictions = []
|
683 |
+
# keep track of the minimum score of null start+end of position 0
|
684 |
+
score_null = 1000000 # large and positive
|
685 |
+
min_null_feature_index = 0 # the paragraph slice with min null score
|
686 |
+
null_start_logit = 0 # the start logit at the slice with min null score
|
687 |
+
null_end_logit = 0 # the end logit at the slice with min null score
|
688 |
+
for (feature_index, feature) in enumerate(features):
|
689 |
+
if feature.unique_id not in unique_id_to_result:
|
690 |
+
logging.info("Skip eval example %s, not in pred.", feature.unique_id)
|
691 |
+
continue
|
692 |
+
result = unique_id_to_result[feature.unique_id]
|
693 |
+
|
694 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
695 |
+
if version_2_with_negative:
|
696 |
+
if xlnet_format:
|
697 |
+
feature_null_score = result.class_logits
|
698 |
+
else:
|
699 |
+
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
700 |
+
if feature_null_score < score_null:
|
701 |
+
score_null = feature_null_score
|
702 |
+
min_null_feature_index = feature_index
|
703 |
+
null_start_logit = result.start_logits[0]
|
704 |
+
null_end_logit = result.end_logits[0]
|
705 |
+
|
706 |
+
doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
|
707 |
+
|
708 |
+
for (start_index, start_logit,
|
709 |
+
end_index, end_logit) in _get_best_indexes_and_logits(
|
710 |
+
result=result,
|
711 |
+
n_best_size=n_best_size,
|
712 |
+
xlnet_format=xlnet_format):
|
713 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
714 |
+
# that the start of the span is in the question. We throw out all
|
715 |
+
# invalid predictions.
|
716 |
+
if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
|
717 |
+
continue
|
718 |
+
if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
|
719 |
+
continue
|
720 |
+
if not feature.token_is_max_context.get(start_index, False):
|
721 |
+
continue
|
722 |
+
if end_index < start_index:
|
723 |
+
continue
|
724 |
+
length = end_index - start_index + 1
|
725 |
+
if length > max_answer_length:
|
726 |
+
continue
|
727 |
+
prelim_predictions.append(
|
728 |
+
_PrelimPrediction(
|
729 |
+
feature_index=feature_index,
|
730 |
+
start_index=start_index - doc_offset,
|
731 |
+
end_index=end_index - doc_offset,
|
732 |
+
start_logit=start_logit,
|
733 |
+
end_logit=end_logit))
|
734 |
+
|
735 |
+
if version_2_with_negative and not xlnet_format:
|
736 |
+
prelim_predictions.append(
|
737 |
+
_PrelimPrediction(
|
738 |
+
feature_index=min_null_feature_index,
|
739 |
+
start_index=-1,
|
740 |
+
end_index=-1,
|
741 |
+
start_logit=null_start_logit,
|
742 |
+
end_logit=null_end_logit))
|
743 |
+
prelim_predictions = sorted(
|
744 |
+
prelim_predictions,
|
745 |
+
key=lambda x: (x.start_logit + x.end_logit),
|
746 |
+
reverse=True)
|
747 |
+
|
748 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
749 |
+
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
750 |
+
|
751 |
+
seen_predictions = {}
|
752 |
+
nbest = []
|
753 |
+
for pred in prelim_predictions:
|
754 |
+
if len(nbest) >= n_best_size:
|
755 |
+
break
|
756 |
+
feature = features[pred.feature_index]
|
757 |
+
if pred.start_index >= 0 or xlnet_format: # this is a non-null prediction
|
758 |
+
tok_start_to_orig_index = feature.tok_start_to_orig_index
|
759 |
+
tok_end_to_orig_index = feature.tok_end_to_orig_index
|
760 |
+
start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
761 |
+
end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
762 |
+
|
763 |
+
paragraph_text = example.paragraph_text
|
764 |
+
final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
|
765 |
+
if final_text in seen_predictions:
|
766 |
+
continue
|
767 |
+
|
768 |
+
seen_predictions[final_text] = True
|
769 |
+
else:
|
770 |
+
final_text = ""
|
771 |
+
seen_predictions[final_text] = True
|
772 |
+
|
773 |
+
nbest.append(
|
774 |
+
_NbestPrediction(
|
775 |
+
text=final_text,
|
776 |
+
start_logit=pred.start_logit,
|
777 |
+
end_logit=pred.end_logit))
|
778 |
+
|
779 |
+
# if we didn't include the empty option in the n-best, include it
|
780 |
+
if version_2_with_negative and not xlnet_format:
|
781 |
+
if "" not in seen_predictions:
|
782 |
+
nbest.append(
|
783 |
+
_NbestPrediction(
|
784 |
+
text="", start_logit=null_start_logit,
|
785 |
+
end_logit=null_end_logit))
|
786 |
+
# In very rare edge cases we could have no valid predictions. So we
|
787 |
+
# just create a nonce prediction in this case to avoid failure.
|
788 |
+
if not nbest:
|
789 |
+
nbest.append(
|
790 |
+
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
791 |
+
|
792 |
+
assert len(nbest) >= 1
|
793 |
+
|
794 |
+
total_scores = []
|
795 |
+
best_non_null_entry = None
|
796 |
+
for entry in nbest:
|
797 |
+
total_scores.append(entry.start_logit + entry.end_logit)
|
798 |
+
if not best_non_null_entry:
|
799 |
+
if entry.text:
|
800 |
+
best_non_null_entry = entry
|
801 |
+
|
802 |
+
probs = _compute_softmax(total_scores)
|
803 |
+
|
804 |
+
nbest_json = []
|
805 |
+
for (i, entry) in enumerate(nbest):
|
806 |
+
output = collections.OrderedDict()
|
807 |
+
output["text"] = entry.text
|
808 |
+
output["probability"] = probs[i]
|
809 |
+
output["start_logit"] = entry.start_logit
|
810 |
+
output["end_logit"] = entry.end_logit
|
811 |
+
nbest_json.append(output)
|
812 |
+
|
813 |
+
assert len(nbest_json) >= 1
|
814 |
+
|
815 |
+
if not version_2_with_negative:
|
816 |
+
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
817 |
+
else:
|
818 |
+
assert best_non_null_entry is not None
|
819 |
+
if xlnet_format:
|
820 |
+
score_diff = score_null
|
821 |
+
scores_diff_json[example.qas_id] = score_diff
|
822 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
823 |
+
else:
|
824 |
+
# predict "" iff the null score - the score of best non-null > threshold
|
825 |
+
score_diff = score_null - best_non_null_entry.start_logit - (
|
826 |
+
best_non_null_entry.end_logit)
|
827 |
+
scores_diff_json[example.qas_id] = score_diff
|
828 |
+
if score_diff > null_score_diff_threshold:
|
829 |
+
all_predictions[example.qas_id] = ""
|
830 |
+
else:
|
831 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
832 |
+
|
833 |
+
all_nbest_json[example.qas_id] = nbest_json
|
834 |
+
|
835 |
+
return all_predictions, all_nbest_json, scores_diff_json
|
836 |
+
|
837 |
+
|
838 |
+
def write_to_json_files(json_records, json_file):
|
839 |
+
with tf.io.gfile.GFile(json_file, "w") as writer:
|
840 |
+
writer.write(json.dumps(json_records, indent=4) + "\n")
|
841 |
+
|
842 |
+
|
843 |
+
def _get_best_indexes_and_logits(result,
|
844 |
+
n_best_size,
|
845 |
+
xlnet_format=False):
|
846 |
+
"""Generates the n-best indexes and logits from a list."""
|
847 |
+
if xlnet_format:
|
848 |
+
for i in range(n_best_size):
|
849 |
+
for j in range(n_best_size):
|
850 |
+
j_index = i * n_best_size + j
|
851 |
+
yield (result.start_indexes[i], result.start_logits[i],
|
852 |
+
result.end_indexes[j_index], result.end_logits[j_index])
|
853 |
+
else:
|
854 |
+
start_index_and_score = sorted(enumerate(result.start_logits),
|
855 |
+
key=lambda x: x[1], reverse=True)
|
856 |
+
end_index_and_score = sorted(enumerate(result.end_logits),
|
857 |
+
key=lambda x: x[1], reverse=True)
|
858 |
+
for i in range(len(start_index_and_score)):
|
859 |
+
if i >= n_best_size:
|
860 |
+
break
|
861 |
+
for j in range(len(end_index_and_score)):
|
862 |
+
if j >= n_best_size:
|
863 |
+
break
|
864 |
+
yield (start_index_and_score[i][0], start_index_and_score[i][1],
|
865 |
+
end_index_and_score[j][0], end_index_and_score[j][1])
|
866 |
+
|
867 |
+
|
868 |
+
def _compute_softmax(scores):
|
869 |
+
"""Compute softmax probability over raw logits."""
|
870 |
+
if not scores:
|
871 |
+
return []
|
872 |
+
|
873 |
+
max_score = None
|
874 |
+
for score in scores:
|
875 |
+
if max_score is None or score > max_score:
|
876 |
+
max_score = score
|
877 |
+
|
878 |
+
exp_scores = []
|
879 |
+
total_sum = 0.0
|
880 |
+
for score in scores:
|
881 |
+
x = math.exp(score - max_score)
|
882 |
+
exp_scores.append(x)
|
883 |
+
total_sum += x
|
884 |
+
|
885 |
+
probs = []
|
886 |
+
for score in exp_scores:
|
887 |
+
probs.append(score / total_sum)
|
888 |
+
return probs
|
889 |
+
|
890 |
+
|
891 |
+
class FeatureWriter(object):
|
892 |
+
"""Writes InputFeature to TF example file."""
|
893 |
+
|
894 |
+
def __init__(self, filename, is_training):
|
895 |
+
self.filename = filename
|
896 |
+
self.is_training = is_training
|
897 |
+
self.num_features = 0
|
898 |
+
tf.io.gfile.makedirs(os.path.dirname(filename))
|
899 |
+
self._writer = tf.io.TFRecordWriter(filename)
|
900 |
+
|
901 |
+
def process_feature(self, feature):
|
902 |
+
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
|
903 |
+
self.num_features += 1
|
904 |
+
|
905 |
+
def create_int_feature(values):
|
906 |
+
feature = tf.train.Feature(
|
907 |
+
int64_list=tf.train.Int64List(value=list(values)))
|
908 |
+
return feature
|
909 |
+
|
910 |
+
features = collections.OrderedDict()
|
911 |
+
features["unique_ids"] = create_int_feature([feature.unique_id])
|
912 |
+
features["input_ids"] = create_int_feature(feature.input_ids)
|
913 |
+
features["input_mask"] = create_int_feature(feature.input_mask)
|
914 |
+
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
915 |
+
if feature.paragraph_mask is not None:
|
916 |
+
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
|
917 |
+
if feature.class_index is not None:
|
918 |
+
features["class_index"] = create_int_feature([feature.class_index])
|
919 |
+
|
920 |
+
if self.is_training:
|
921 |
+
features["start_positions"] = create_int_feature([feature.start_position])
|
922 |
+
features["end_positions"] = create_int_feature([feature.end_position])
|
923 |
+
impossible = 0
|
924 |
+
if feature.is_impossible:
|
925 |
+
impossible = 1
|
926 |
+
features["is_impossible"] = create_int_feature([impossible])
|
927 |
+
|
928 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
929 |
+
self._writer.write(tf_example.SerializeToString())
|
930 |
+
|
931 |
+
def close(self):
|
932 |
+
self._writer.close()
|
933 |
+
|
934 |
+
|
935 |
+
def generate_tf_record_from_json_file(input_file_path,
|
936 |
+
sp_model_file,
|
937 |
+
output_path,
|
938 |
+
translated_input_folder=None,
|
939 |
+
max_seq_length=384,
|
940 |
+
do_lower_case=True,
|
941 |
+
max_query_length=64,
|
942 |
+
doc_stride=128,
|
943 |
+
xlnet_format=False,
|
944 |
+
version_2_with_negative=False):
|
945 |
+
"""Generates and saves training data into a tf record file."""
|
946 |
+
train_examples = read_squad_examples(
|
947 |
+
input_file=input_file_path,
|
948 |
+
is_training=True,
|
949 |
+
version_2_with_negative=version_2_with_negative,
|
950 |
+
translated_input_folder=translated_input_folder)
|
951 |
+
tokenizer = tokenization.FullSentencePieceTokenizer(
|
952 |
+
sp_model_file=sp_model_file)
|
953 |
+
train_writer = FeatureWriter(
|
954 |
+
filename=output_path, is_training=True)
|
955 |
+
number_of_examples = convert_examples_to_features(
|
956 |
+
examples=train_examples,
|
957 |
+
tokenizer=tokenizer,
|
958 |
+
max_seq_length=max_seq_length,
|
959 |
+
doc_stride=doc_stride,
|
960 |
+
max_query_length=max_query_length,
|
961 |
+
is_training=True,
|
962 |
+
output_fn=train_writer.process_feature,
|
963 |
+
xlnet_format=xlnet_format,
|
964 |
+
do_lower_case=do_lower_case)
|
965 |
+
train_writer.close()
|
966 |
+
|
967 |
+
meta_data = {
|
968 |
+
"task_type": "bert_squad",
|
969 |
+
"train_data_size": number_of_examples,
|
970 |
+
"max_seq_length": max_seq_length,
|
971 |
+
"max_query_length": max_query_length,
|
972 |
+
"doc_stride": doc_stride,
|
973 |
+
"version_2_with_negative": version_2_with_negative,
|
974 |
+
}
|
975 |
+
|
976 |
+
return meta_data
|
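Similarly, a minimal usage sketch for the SentencePiece-based variant in squad_lib_sp.py above; the module name, file paths, and SentencePiece model file are placeholders, not part of this commit.

    import squad_lib_sp  # assumes squad_lib_sp.py is importable from the working directory

    meta_data = squad_lib_sp.generate_tf_record_from_json_file(
        input_file_path="/tmp/train-v2.0.json",       # hypothetical SQuAD 2.0 training JSON
        sp_model_file="/tmp/albert_sp.model",         # hypothetical SentencePiece model file
        output_path="/tmp/squad_sp_train.tf_record",  # destination TFRecord file
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        version_2_with_negative=True)
    # As above, meta_data carries the example count and the preprocessing settings.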
tagging_data_lib.py
ADDED
@@ -0,0 +1,426 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to process data for tagging task such as NER/POS."""
import collections
import os

from absl import logging
import tensorflow as tf, tf_keras

from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization

# A negative label id for the padding label, which will not contribute
# to loss/metrics in training.
_PADDING_LABEL_ID = -1

# The special unknown token, used to substitute a word which has too many
# subwords after tokenization.
_UNK_TOKEN = "[UNK]"


class InputExample(object):
  """A single training/test example for token classification."""

  def __init__(self,
               sentence_id,
               sub_sentence_id=0,
               words=None,
               label_ids=None):
    """Constructs an InputExample."""
    self.sentence_id = sentence_id
    self.sub_sentence_id = sub_sentence_id
    self.words = words if words else []
    self.label_ids = label_ids if label_ids else []

  def add_word_and_label_id(self, word, label_id):
    """Adds word and label_id pair in the example."""
    self.words.append(word)
    self.label_ids.append(label_id)


def _read_one_file(file_name, label_list):
  """Reads one file and returns a list of `InputExample` instances."""
  lines = tf.io.gfile.GFile(file_name, "r").readlines()
  examples = []
  label_id_map = {label: i for i, label in enumerate(label_list)}
  sentence_id = 0
  example = InputExample(sentence_id=0)
  for line in lines:
    line = line.strip("\n")
    if line:
      # The format is: <token>\t<label> for train/dev set and <token> for test.
      items = line.split("\t")
      assert len(items) == 2 or len(items) == 1
      token = items[0].strip()

      # Assign a dummy label_id for test set
      label_id = label_id_map[items[1].strip()] if len(items) == 2 else 0
      example.add_word_and_label_id(token, label_id)
    else:
      # Empty line indicates a new sentence.
      if example.words:
        examples.append(example)
        sentence_id += 1
        example = InputExample(sentence_id=sentence_id)

  if example.words:
    examples.append(example)
  return examples


class PanxProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Panx data set."""
  supported_languages = [
      "ar", "he", "vi", "id", "jv", "ms", "tl", "eu", "ml", "ta", "te", "af",
      "nl", "en", "de", "el", "bn", "hi", "mr", "ur", "fa", "fr", "it", "pt",
      "es", "bg", "ru", "ja", "ka", "ko", "th", "sw", "yo", "my", "zh", "kk",
      "tr", "et", "fi", "hu"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               only_use_en_train=True,
               only_use_en_dev=True):
    """See base class.

    Args:
      process_text_fn: See base class.
      only_use_en_train: If True, only use english training data. Otherwise,
        use training data from all languages.
      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
        data from all languages.
    """
    super(PanxProcessor, self).__init__(process_text_fn)
    self.only_use_en_train = only_use_en_train
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    examples = _read_one_file(
        os.path.join(data_dir, "train-en.tsv"), self.get_labels())
    if not self.only_use_en_train:
      for language in self.supported_languages:
        if language == "en":
          continue
        examples.extend(
            _read_one_file(
                os.path.join(data_dir, f"train-{language}.tsv"),
                self.get_labels()))
    return examples

  def get_dev_examples(self, data_dir):
    examples = _read_one_file(
        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
    if not self.only_use_en_dev:
      for language in self.supported_languages:
        if language == "en":
          continue
        examples.extend(
            _read_one_file(
                os.path.join(data_dir, f"dev-{language}.tsv"),
                self.get_labels()))
    return examples

  def get_test_examples(self, data_dir):
    examples_dict = {}
    for language in self.supported_languages:
      examples_dict[language] = _read_one_file(
          os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
    return examples_dict

  def get_labels(self):
    return ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]

  @staticmethod
  def get_processor_name():
    return "panx"


class UdposProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Udpos data set."""
  supported_languages = [
      "af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl", "pt", "ru",
      "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               only_use_en_train=True,
               only_use_en_dev=True):
    """See base class.

    Args:
      process_text_fn: See base class.
      only_use_en_train: If True, only use english training data. Otherwise,
        use training data from all languages.
      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
        data from all languages.
    """
    super(UdposProcessor, self).__init__(process_text_fn)
    self.only_use_en_train = only_use_en_train
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    if self.only_use_en_train:
      examples = _read_one_file(
          os.path.join(data_dir, "train-en.tsv"), self.get_labels())
    else:
      examples = []
      # Uses glob because some languages are missing in train.
      for filepath in tf.io.gfile.glob(os.path.join(data_dir, "train-*.tsv")):
        examples.extend(
            _read_one_file(
                filepath,
                self.get_labels()))
    return examples

  def get_dev_examples(self, data_dir):
    if self.only_use_en_dev:
      examples = _read_one_file(
          os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
    else:
      examples = []
      for filepath in tf.io.gfile.glob(os.path.join(data_dir, "dev-*.tsv")):
        examples.extend(
            _read_one_file(
                filepath,
                self.get_labels()))
    return examples

  def get_test_examples(self, data_dir):
    examples_dict = {}
    for language in self.supported_languages:
      examples_dict[language] = _read_one_file(
          os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
    return examples_dict

  def get_labels(self):
    return [
        "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
        "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"
    ]

  @staticmethod
  def get_processor_name():
    return "udpos"


def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
  """Tokenizes words and breaks long example into short ones."""
  # Needs additional [CLS] and [SEP] tokens.
  max_length = max_length - 2
  new_examples = []
  new_example = InputExample(sentence_id=example.sentence_id, sub_sentence_id=0)
  if any([x < 0 for x in example.label_ids]):
    raise ValueError("Unexpected negative label_id: %s" % example.label_ids)

  for i, word in enumerate(example.words):
    if text_preprocessing:
      word = text_preprocessing(word)
    subwords = tokenizer.tokenize(word)
    if (not subwords or len(subwords) > max_length) and word:
      subwords = [_UNK_TOKEN]

    if len(subwords) + len(new_example.words) > max_length:
      # Start a new example.
      new_examples.append(new_example)
      last_sub_sentence_id = new_example.sub_sentence_id
      new_example = InputExample(
          sentence_id=example.sentence_id,
          sub_sentence_id=last_sub_sentence_id + 1)

    for j, subword in enumerate(subwords):
      # Use the real label for the first subword, and pad label for
      # the remainings.
      subword_label = example.label_ids[i] if j == 0 else _PADDING_LABEL_ID
      new_example.add_word_and_label_id(subword, subword_label)

  if new_example.words:
    new_examples.append(new_example)

  return new_examples


def _convert_single_example(example, max_seq_length, tokenizer):
  """Converts an `InputExample` instance to a `tf.train.Example` instance."""
  tokens = ["[CLS]"]
  tokens.extend(example.words)
  tokens.append("[SEP]")
  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  label_ids = [_PADDING_LABEL_ID]
  label_ids.extend(example.label_ids)
  label_ids.append(_PADDING_LABEL_ID)

  segment_ids = [0] * len(input_ids)
  input_mask = [1] * len(input_ids)

  # Pad up to the sequence length.
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)
    label_ids.append(_PADDING_LABEL_ID)

  def create_int_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  features = collections.OrderedDict()
  features["input_ids"] = create_int_feature(input_ids)
  features["input_mask"] = create_int_feature(input_mask)
  features["segment_ids"] = create_int_feature(segment_ids)
  features["label_ids"] = create_int_feature(label_ids)
  features["sentence_id"] = create_int_feature([example.sentence_id])
  features["sub_sentence_id"] = create_int_feature([example.sub_sentence_id])

  tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  return tf_example


def write_example_to_file(examples,
                          tokenizer,
                          max_seq_length,
                          output_file,
                          text_preprocessing=None):
  """Writes `InputExample`s into a tfrecord file with `tf.train.Example` protos.

  Note that the words inside each example will be tokenized and be applied by
  `text_preprocessing` if available. Also, if the length of sentence (plus
  special [CLS] and [SEP] tokens) exceeds `max_seq_length`, the long sentence
  will be broken into multiple short examples. For example:

  Example (text_preprocessing=lowercase, max_seq_length=5)
    words:        ["What", "a", "great", "weekend"]
    labels:       [     7,   5,       9,        10]
    sentence_id:  0
    preprocessed: ["what", "a", "great", "weekend"]
    tokenized:    ["what", "a", "great", "week", "##end"]

  will result in two tf.example protos:

    tokens:      ["[CLS]", "what", "a", "great", "[SEP]"]
    label_ids:   [     -1,      7,   5,       9,      -1]
    input_mask:  [      1,      1,   1,       1,       1]
    segment_ids: [      0,      0,   0,       0,       0]
    input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
    sentence_id: 0

    tokens:      ["[CLS]", "week", "##end", "[SEP]", "[PAD]"]
    label_ids:   [     -1,     10,      -1,      -1,      -1]
    input_mask:  [      1,      1,       1,       0,       0]
    segment_ids: [      0,      0,       0,       0,       0]
    input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
    sentence_id: 0

  Note the use of -1 in `label_ids` to indicate that a token should not be
  considered for classification (e.g., trailing ## wordpieces or special
  token). Token classification models should accordingly ignore these when
  calculating loss, metrics, etc...

  Args:
    examples: A list of `InputExample` instances.
    tokenizer: The tokenizer to be applied on the data.
    max_seq_length: Maximum length of generated sequences.
    output_file: The name of the output tfrecord file.
    text_preprocessing: optional preprocessing run on each word prior to
      tokenization.

  Returns:
    The total number of tf.train.Example proto written to file.
  """
  tf.io.gfile.makedirs(os.path.dirname(output_file))
  writer = tf.io.TFRecordWriter(output_file)
  num_tokenized_examples = 0
  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      logging.info("Writing example %d of %d to %s", ex_index, len(examples),
                   output_file)

    tokenized_examples = _tokenize_example(example, max_seq_length, tokenizer,
                                           text_preprocessing)
    num_tokenized_examples += len(tokenized_examples)
    for per_tokenized_example in tokenized_examples:
      tf_example = _convert_single_example(per_tokenized_example,
                                           max_seq_length, tokenizer)
      writer.write(tf_example.SerializeToString())

  writer.close()
  return num_tokenized_examples


def token_classification_meta_data(train_data_size,
                                   max_seq_length,
                                   num_labels,
                                   eval_data_size=None,
                                   test_data_size=None,
                                   label_list=None,
                                   processor_type=None):
  """Creates metadata for tagging (token classification) datasets."""
  meta_data = {
      "train_data_size": train_data_size,
      "max_seq_length": max_seq_length,
      "num_labels": num_labels,
      "task_type": "tagging",
      "label_type": "int",
      "label_shape": [max_seq_length],
  }
  if eval_data_size:
    meta_data["eval_data_size"] = eval_data_size
  if test_data_size:
    meta_data["test_data_size"] = test_data_size
  if label_list:
    meta_data["label_list"] = label_list
  if processor_type:
    meta_data["processor_type"] = processor_type

  return meta_data


def generate_tf_record_from_data_file(processor, data_dir, tokenizer,
                                      max_seq_length, train_data_output_path,
                                      eval_data_output_path,
                                      test_data_output_path,
                                      text_preprocessing):
  """Generates tfrecord files from the raw data."""
  common_kwargs = dict(
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      text_preprocessing=text_preprocessing)
  train_examples = processor.get_train_examples(data_dir)
  train_data_size = write_example_to_file(
      train_examples, output_file=train_data_output_path, **common_kwargs)

  eval_examples = processor.get_dev_examples(data_dir)
  eval_data_size = write_example_to_file(
      eval_examples, output_file=eval_data_output_path, **common_kwargs)

  test_input_data_examples = processor.get_test_examples(data_dir)
  test_data_size = {}
  for language, examples in test_input_data_examples.items():
    test_data_size[language] = write_example_to_file(
        examples,
        output_file=test_data_output_path.format(language),
        **common_kwargs)

  labels = processor.get_labels()
  meta_data = token_classification_meta_data(
      train_data_size,
      max_seq_length,
      len(labels),
      eval_data_size,
      test_data_size,
      label_list=labels,
      processor_type=processor.get_processor_name())
  return meta_data
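A minimal end-to-end sketch of how the pieces in this file fit together (mirroring the call made in the test below); the vocab file, data directory, and output paths are placeholders, not files from this upload:

  # Hypothetical usage; paths are illustrative only.
  from official.nlp.data import tagging_data_lib
  from official.nlp.tools import tokenization

  tokenizer = tokenization.FullTokenizer(
      vocab_file="/path/to/vocab.txt", do_lower_case=True)
  metadata = tagging_data_lib.generate_tf_record_from_data_file(
      processor=tagging_data_lib.PanxProcessor(),
      data_dir="/path/to/panx_data",
      tokenizer=tokenizer,
      max_seq_length=128,
      train_data_output_path="/tmp/train.tfrecord",
      eval_data_output_path="/tmp/eval.tfrecord",
      test_data_output_path="/tmp/test_{}.tfrecord",
      text_preprocessing=tokenization.convert_to_unicode)
  # `metadata` carries sizes, labels and the processor name for task configs.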
tagging_data_lib_test.py
ADDED
@@ -0,0 +1,108 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.nlp.data.tagging_data_lib."""
import os
import random

from absl.testing import parameterized
import tensorflow as tf, tf_keras

from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization


def _create_fake_file(filename, labels, is_test):

  def write_one_sentence(writer, length):
    for _ in range(length):
      line = "hiworld"
      if not is_test:
        line += "\t%s" % (labels[random.randint(0, len(labels) - 1)])
      writer.write(line + "\n")

  # Writes two sentences with length of 3 and 12 respectively.
  with tf.io.gfile.GFile(filename, "w") as writer:
    write_one_sentence(writer, 3)
    writer.write("\n")
    write_one_sentence(writer, 12)


class TaggingDataLibTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(TaggingDataLibTest, self).setUp()

    self.processors = {
        "panx": tagging_data_lib.PanxProcessor,
        "udpos": tagging_data_lib.UdposProcessor,
    }
    self.vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")
    with tf.io.gfile.GFile(self.vocab_file, "w") as writer:
      writer.write("\n".join(["[CLS]", "[SEP]", "hi", "##world", "[UNK]"]))

  @parameterized.parameters(
      {"task_type": "panx"},
      {"task_type": "udpos"},
  )
  def test_generate_tf_record(self, task_type):
    processor = self.processors[task_type]()
    input_data_dir = os.path.join(self.get_temp_dir(), task_type)
    tf.io.gfile.mkdir(input_data_dir)
    # Write fake train file.
    _create_fake_file(
        os.path.join(input_data_dir, "train-en.tsv"),
        processor.get_labels(),
        is_test=False)

    # Write fake dev file.
    _create_fake_file(
        os.path.join(input_data_dir, "dev-en.tsv"),
        processor.get_labels(),
        is_test=False)

    # Write fake test files.
    for lang in processor.supported_languages:
      _create_fake_file(
          os.path.join(input_data_dir, "test-%s.tsv" % lang),
          processor.get_labels(),
          is_test=True)

    output_path = os.path.join(self.get_temp_dir(), task_type, "output")
    tokenizer = tokenization.FullTokenizer(
        vocab_file=self.vocab_file, do_lower_case=True)
    metadata = tagging_data_lib.generate_tf_record_from_data_file(
        processor,
        input_data_dir,
        tokenizer,
        max_seq_length=8,
        train_data_output_path=os.path.join(output_path, "train.tfrecord"),
        eval_data_output_path=os.path.join(output_path, "eval.tfrecord"),
        test_data_output_path=os.path.join(output_path, "test_{}.tfrecord"),
        text_preprocessing=tokenization.convert_to_unicode)

    self.assertEqual(metadata["train_data_size"], 5)
    files = tf.io.gfile.glob(output_path + "/*")
    expected_files = []
    expected_files.append(os.path.join(output_path, "train.tfrecord"))
    expected_files.append(os.path.join(output_path, "eval.tfrecord"))
    for lang in processor.supported_languages:
      expected_files.append(
          os.path.join(output_path, "test_%s.tfrecord" % lang))

    self.assertCountEqual(files, expected_files)


if __name__ == "__main__":
  tf.test.main()
tagging_dataloader.py
ADDED
@@ -0,0 +1,90 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loads dataset for the tagging (e.g., NER/POS) task."""
import dataclasses
from typing import Mapping, Optional

import tensorflow as tf, tf_keras
from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Data config for tagging (tasks/tagging)."""
  is_training: bool = True
  seq_length: int = 128
  include_sentence_id: bool = False
  file_type: str = 'tfrecord'


@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader(data_loader.DataLoader):
  """A class to load dataset for tagging (e.g., NER and POS) task."""

  def __init__(self, params: TaggingDataConfig):
    self._params = params
    self._seq_length = params.seq_length
    self._include_sentence_id = params.include_sentence_id

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
    }
    if self._include_sentence_id:
      name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
      name_to_features['sub_sentence_id'] = tf.io.FixedLenFeature([], tf.int64)

    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    if self._include_sentence_id:
      x['sentence_id'] = record['sentence_id']
      x['sub_sentence_id'] = record['sub_sentence_id']

    y = record['label_ids']
    return (x, y)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
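A minimal sketch of consuming the generated tfrecords with this loader (the input path and batch size below are placeholders; the call pattern follows the test file that accompanies this loader):

  # Hypothetical usage; assumes a tfrecord produced by tagging_data_lib.
  from official.nlp.data import tagging_dataloader

  data_config = tagging_dataloader.TaggingDataConfig(
      input_path="/tmp/train.tfrecord",
      seq_length=128,
      global_batch_size=32,
      is_training=True)
  dataset = tagging_dataloader.TaggingDataLoader(data_config).load()
  # Each element is (features, labels): word ids, mask, type ids -> label ids.
  features, labels = next(iter(dataset))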
tagging_dataloader_test.py
ADDED
@@ -0,0 +1,82 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.nlp.data.tagging_data_loader."""
import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras

from official.nlp.data import tagging_dataloader


def _create_fake_dataset(output_path, seq_length, include_sentence_id):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  for i in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    features['input_ids'] = create_int_feature(input_ids)
    features['input_mask'] = create_int_feature(np.ones_like(input_ids))
    features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
    features['label_ids'] = create_int_feature(
        np.random.randint(10, size=(seq_length)))
    if include_sentence_id:
      features['sentence_id'] = create_int_feature([i])
      features['sub_sentence_id'] = create_int_feature([0])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class TaggingDataLoaderTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(True, False)
  def test_load_dataset(self, include_sentence_id):
    seq_length = 16
    batch_size = 10
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(train_data_path, seq_length, include_sentence_id)
    data_config = tagging_dataloader.TaggingDataConfig(
        input_path=train_data_path,
        seq_length=seq_length,
        global_batch_size=batch_size,
        include_sentence_id=include_sentence_id)

    dataset = tagging_dataloader.TaggingDataLoader(data_config).load()
    features, labels = next(iter(dataset))

    expected_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
    if include_sentence_id:
      expected_keys.extend(['sentence_id', 'sub_sentence_id'])
    self.assertCountEqual(expected_keys, features.keys())

    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(labels.shape, (batch_size, seq_length))
    if include_sentence_id:
      self.assertEqual(features['sentence_id'].shape, (batch_size,))
      self.assertEqual(features['sub_sentence_id'].shape, (batch_size,))


if __name__ == '__main__':
  tf.test.main()
train_sentencepiece.py
ADDED
@@ -0,0 +1,133 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A script to train sentencepiece model from tensorflow datasets.

Reserved tokens:
  pad: 0,
  eos: 1,
  unk: 2
(bos is not reserved)
"""

import os
import tempfile
from typing import List, Tuple

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds

from sentencepiece import SentencePieceTrainer


FLAGS = flags.FLAGS
flags.DEFINE_string("output_model_path", None,
                    "Path to save the sentencepiece model.")
flags.mark_flag_as_required("output_model_path")

flags.DEFINE_string("tfds_dir", None, "Directory of the tfds.")
flags.DEFINE_string("tfds_name", "wmt14_translate/de-en",
                    "Name of the dataset we generate the vocabulary from.")
flags.DEFINE_string("tfds_split", "train", "Split of the dataset.")
flags.DEFINE_integer("vocab_size", 32000, "Size of vocabulary.")
flags.DEFINE_integer(
    "max_char", -1,
    "Maximum number of characters to use. "
    "If a non-positive number is provided, all sentences are used.")
flags.DEFINE_string("model_type", "bpe",
                    "Model algorithm: unigram, bpe, word or char.")
flags.DEFINE_float("character_coverage", 0.9995,
                   "Character coverage to determine the minimum symbols")
flags.DEFINE_list(
    "data_keys", ["en", "de"],
    "Comma-separated list of keys to use for training the vocabulary.")


def dump_chars_to_textfile(dataset: tf.data.Dataset,
                           data_keys: Tuple[str],
                           max_char: int = -1):
  """Write part of a TFDS sentence dataset to lines in a text file.

  Args:
    dataset: tf.dataset containing string-data.
    data_keys: what keys in dataset to dump from.
    max_char: max character to dump to text file.

  Returns:
    The name of the temp file the dataset text was dumped to.
  """
  ds_iter = dataset.as_numpy_iterator()
  with tempfile.NamedTemporaryFile(delete=False) as outfp:
    char_count = 0
    while True:
      example = next(ds_iter, None)
      if example is None or (
          max_char > 0 and char_count > max_char):
        break
      for k in data_keys:
        line = example[k] + b"\n"
        char_count += len(line)
        outfp.write(line)
  return outfp.name


def train_sentencepiece(
    file_path: str,
    model_path: str,
    vocab_size: int,
    character_coverage: float,
    model_type: str):
  """Train SentencePiece tokenizer from subset of tf dataset.

  The trained vocabulary model is written to files prefixed by `model_path`.

  Args:
    file_path: path of data to train sentencepiece.
    model_path: path of model file to save vocab model to.
    vocab_size: size of vocab tokens to train.
    character_coverage: amount of characters covered by the model, good
      defaults are 0.9995 for languages with rich character set like Japanese
      or Chinese and 1.0 for other languages with small character set.
    model_type: type of sentencepiece vocab to train.
  """
  argstr = " ".join([
      f"--input={file_path}", f"--vocab_size={vocab_size}",
      f"--character_coverage={character_coverage}",
      f"--model_prefix={model_path}", f"--model_type={model_type}",
      "--bos_id=-1", "--pad_id=0", "--eos_id=1", "--unk_id=2"
  ])
  SentencePieceTrainer.Train(argstr)


def main(argv: List[str]):
  del argv
  builder = tfds.builder(FLAGS.tfds_name, data_dir=FLAGS.tfds_dir)
  ds = builder.as_dataset(split=FLAGS.tfds_split)
  tmp_filename = dump_chars_to_textfile(ds, FLAGS.data_keys, FLAGS.max_char)
  logging.info("Sentencepiece model will be placed here: %s",
               FLAGS.output_model_path)
  train_sentencepiece(tmp_filename,
                      FLAGS.output_model_path,
                      FLAGS.vocab_size,
                      FLAGS.character_coverage,
                      FLAGS.model_type)
  os.remove(tmp_filename)


if __name__ == "__main__":
  app.run(main)
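An illustrative way to drive the helpers above directly from Python rather than through the flags; the import path and the output prefix are assumptions, and the dataset keys follow the script's defaults:

  # Hypothetical direct use of the helpers; paths are illustrative only.
  import tensorflow_datasets as tfds
  from official.nlp.data import train_sentencepiece

  # Dump a capped amount of source/target text to a temp file, then train.
  ds = tfds.load("wmt14_translate/de-en", split="train")
  txt = train_sentencepiece.dump_chars_to_textfile(
      ds, ("en", "de"), max_char=10_000_000)
  train_sentencepiece.train_sentencepiece(
      txt, "/tmp/wmt_spm", vocab_size=32000,
      character_coverage=0.9995, model_type="bpe")
  # SentencePiece writes /tmp/wmt_spm.model and /tmp/wmt_spm.vocab.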
wmt_dataloader.py
ADDED
@@ -0,0 +1,295 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline for the transformer model to read, filter, and batch examples.

Batching scheme

Prior to batching, elements in the dataset are grouped by length (max between
'inputs' and 'targets' length). Each group is then batched such that:
  group_batch_size * length <= batch_size.

Another way to view batch_size is the maximum number of tokens in each batch.

Once batched, each element in the dataset will have the shape:
  {'inputs': [group_batch_size, padded_input_length],
   'targets': [group_batch_size, padded_target_length]}
Lengths are padded to the longest 'inputs' or 'targets' sequence in the batch
(padded_input_length and padded_target_length can be different).

This batching scheme decreases the fraction of padding tokens per training
batch, thus improving the training speed significantly.
"""
from typing import Dict, Optional

import dataclasses
import tensorflow as tf, tf_keras
import tensorflow_text as tftxt
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory

# Example grouping constants. Defines length boundaries for each group.
# These values are the defaults used in Tensor2Tensor.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1


def _get_example_length(example):
  """Returns the maximum length between the example inputs and targets."""
  length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
  return length


def _create_min_max_boundaries(max_length,
                               min_boundary=_MIN_BOUNDARY,
                               boundary_scale=_BOUNDARY_SCALE):
  """Create min and max boundary lists up to max_length.

  For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
  returned values will be:
    buckets_min = [0, 4, 8, 16]
    buckets_max = [4, 8, 16, 25]

  Args:
    max_length: The maximum length of example in dataset.
    min_boundary: Minimum length in boundary.
    boundary_scale: Amount to scale consecutive boundaries in the list.

  Returns:
    min and max boundary lists
  """
  # Create bucket boundaries list by scaling the previous boundary or adding 1
  # (to ensure increasing boundary sizes).
  bucket_boundaries = []
  x = min_boundary
  while x < max_length:
    bucket_boundaries.append(x)
    x = max(x + 1, int(x * boundary_scale))

  # Create min and max boundary lists from the initial list.
  buckets_min = [0] + bucket_boundaries
  buckets_max = bucket_boundaries + [max_length + 1]
  return buckets_min, buckets_max


def _batch_examples(dataset, batch_size, max_length):
  """Group examples by similar lengths, and return batched dataset.

  Each batch of similar-length examples are padded to the same length, and may
  have different number of elements in each batch, such that:
    group_batch_size * padded_length <= batch_size.

  This decreases the number of padding tokens per batch, which improves the
  training speed.

  Args:
    dataset: Dataset of unbatched examples.
    batch_size: Max number of tokens per batch of examples.
    max_length: Max number of tokens in an example input or target sequence.

  Returns:
    Dataset of batched examples with similar lengths.
  """
  # Get min and max boundary lists for each example. These are used to
  # calculate the `bucket_id`, which is the index at which:
  # buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
  # Note that using both min and max lists improves the performance.
  buckets_min, buckets_max = _create_min_max_boundaries(max_length)

  # Create list of batch sizes for each bucket_id, so that
  # bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
  bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]

  # Validates bucket batch sizes.
  if any([batch_size <= 0 for batch_size in bucket_batch_sizes]):
    raise ValueError(
        'The token budget, global batch size, is too small to fit at least '
        'one example in every length bucket: %s' % str(bucket_batch_sizes))

  # bucket_id will be a tensor, so convert this list to a tensor as well.
  bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)

  def example_to_bucket_id(example):
    """Return int64 bucket id for this example, calculated based on length."""
    example_input = example['inputs']
    example_target = example['targets']
    seq_length = _get_example_length((example_input, example_target))

    conditions_c = tf.logical_and(
        tf.less_equal(buckets_min, seq_length), tf.less(seq_length,
                                                        buckets_max))
    bucket_id = tf.reduce_min(tf.where(conditions_c))
    return bucket_id

  def window_size_fn(bucket_id):
    """Return number of examples to be grouped when given a bucket id."""
    return bucket_batch_sizes[bucket_id]

  def batching_fn(bucket_id, grouped_dataset):
    """Batch and add padding to a dataset of elements with similar lengths."""
    bucket_batch_size = window_size_fn(bucket_id)

    # Batch the dataset and add padding so that all input sequences in the
    # examples have the same length, and all target sequences have the same
    # lengths as well. Resulting lengths of inputs and targets can differ.
    padded_shapes = dict([
        (name, [None] * len(spec.shape))
        for name, spec in grouped_dataset.element_spec.items()
    ])
    return grouped_dataset.padded_batch(bucket_batch_size, padded_shapes)

  return dataset.apply(
      tf.data.experimental.group_by_window(
          key_func=example_to_bucket_id,
          reduce_func=batching_fn,
          window_size=None,
          window_size_func=window_size_fn))


@dataclasses.dataclass
class WMTDataConfig(cfg.DataConfig):
  """Data config for WMT translation."""
  max_seq_length: int = 64
  static_batch: bool = False
  sentencepiece_model_path: str = ''
  src_lang: str = ''
  tgt_lang: str = ''
  transform_and_batch: bool = True
  has_unique_id: bool = False


@data_loader_factory.register_data_loader_cls(WMTDataConfig)
class WMTDataLoader(data_loader.DataLoader):
  """A class to load dataset for WMT translation task."""

  def __init__(self, params: WMTDataConfig):
    self._params = params
    self._max_seq_length = params.max_seq_length
    self._static_batch = params.static_batch
    self._global_batch_size = params.global_batch_size
    if self._params.transform_and_batch:
      self._tokenizer = tftxt.SentencepieceTokenizer(
          model=tf.io.gfile.GFile(params.sentencepiece_model_path, 'rb').read(),
          add_eos=True)

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        self._params.src_lang: tf.io.FixedLenFeature([], tf.string),
        self._params.tgt_lang: tf.io.FixedLenFeature([], tf.string),
    }
    if self._params.has_unique_id:
      name_to_features['unique_id'] = tf.io.FixedLenFeature([], tf.int64)
    example = tf.io.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t
    return example

  def _tokenize(self, inputs) -> Dict[str, tf.Tensor]:
    tokenized_inputs = {}
    for k, v in inputs.items():
      if k == self._params.src_lang:
        tokenized_inputs['inputs'] = self._tokenizer.tokenize(v)
      elif k == self._params.tgt_lang:
        tokenized_inputs['targets'] = self._tokenizer.tokenize(v)
      else:
        tokenized_inputs[k] = v
    return tokenized_inputs

  def _filter_max_length(self, inputs):
    return tf.logical_and(
        tf.shape(inputs['inputs'])[0] <= self._max_seq_length,
        tf.shape(inputs['targets'])[0] <= self._max_seq_length)

  def _maybe_truncate(self, inputs):
    truncated_inputs = {}
    for k, v in inputs.items():
      if k == 'inputs' or k == 'targets':
        truncated_inputs[k] = tf.pad(
            v[:self._max_seq_length - 1], [[0, 1]],
            constant_values=1) if tf.shape(v)[0] > self._max_seq_length else v
      else:
        truncated_inputs[k] = v
    return truncated_inputs

  def _tokenize_bucketize_and_batch(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    dataset = dataset.map(
        self._tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._params.is_training:
      dataset = dataset.filter(self._filter_max_length)
    else:
      dataset = dataset.map(
          self._maybe_truncate,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size
    if self._static_batch:
      padded_shapes = {}
      for name, _ in dataset.element_spec.items():
        if name == 'unique_id':
          padded_shapes[name] = []
        else:
          padded_shapes[name] = [self._max_seq_length
                                ] if self._static_batch else [None]
      batch_size = per_replica_batch_size
      if self._params.is_training:
        batch_size = int(batch_size // self._max_seq_length)
      dataset = dataset.padded_batch(
          batch_size,
          padded_shapes,
          drop_remainder=True)
    else:
      # Group and batch such that each batch has examples of similar length.
      dataset = _batch_examples(dataset, per_replica_batch_size,
                                self._max_seq_length)
    # Prefetch the next element to improve speed of input pipeline.
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    decoder_fn = None
    # Only decode for TFRecords.
    if self._params.input_path:
      decoder_fn = self._decode

    def _identity(
        dataset, input_context: Optional[tf.distribute.InputContext] = None):
      del input_context
      return dataset

    transform_and_batch_fn = _identity
    if self._params.transform_and_batch:
      transform_and_batch_fn = self._tokenize_bucketize_and_batch

    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=decoder_fn,
        transform_and_batch_fn=transform_and_batch_fn)
    return reader.read(input_context)
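To make the length-bucketing scheme above concrete, a small sketch that reproduces the worked example from the `_create_min_max_boundaries` docstring (it calls a module-private helper, so it is for illustration only; the token budget of 64 is an arbitrary choice):

  # Illustration of the bucket boundaries and per-bucket batch sizes.
  from official.nlp.data import wmt_dataloader

  buckets_min, buckets_max = wmt_dataloader._create_min_max_boundaries(
      max_length=24, min_boundary=4, boundary_scale=2)
  print(buckets_min)  # [0, 4, 8, 16]
  print(buckets_max)  # [4, 8, 16, 25]
  # With a token budget (batch_size) of 64, each bucket batches
  # 64 // buckets_max[i] examples, i.e. [16, 8, 4, 2], so longer
  # sequences are grouped into proportionally smaller batches.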
wmt_dataloader_test.py
ADDED
@@ -0,0 +1,130 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for official.nlp.data.wmt_dataloader."""
import os
from absl.testing import parameterized

import tensorflow as tf, tf_keras

from sentencepiece import SentencePieceTrainer
from official.nlp.data import wmt_dataloader


def _generate_line_file(filepath, lines):
  with tf.io.gfile.GFile(filepath, 'w') as f:
    for l in lines:
      f.write('{}\n'.format(l))


def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
  writer = tf.io.TFRecordWriter(filepath)
  for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
    features = {
        'en': tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[src.encode()])),
        'reverse_en': tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[tgt.encode()])),
    }
    if unique_id:
      features['unique_id'] = tf.train.Feature(
          int64_list=tf.train.Int64List(value=[i]))
    example = tf.train.Example(
        features=tf.train.Features(
            feature=features))
    writer.write(example.SerializeToString())
  writer.close()


def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
  argstr = ' '.join([
      f'--input={input_path}', f'--vocab_size={vocab_size}',
      '--character_coverage=0.995',
      f'--model_prefix={model_path}', '--model_type=bpe',
      '--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}', '--unk_id=2'
  ])
  SentencePieceTrainer.Train(argstr)


class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(WMTDataLoaderTest, self).setUp()
    self._temp_dir = self.get_temp_dir()
    src_lines = [
        'abc ede fg',
        'bbcd ef a g',
        'de f a a g'
    ]
    tgt_lines = [
        'dd cc a ef g',
        'bcd ef a g',
        'gef cd ba'
    ]
    self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
    _generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
    self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
    _generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
                          unique_id=True)
    self._sentencepeice_input_path = os.path.join(self._temp_dir, 'inputs.txt')
    _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
    sentencepeice_model_prefix = os.path.join(self._temp_dir, 'sp')
    _train_sentencepiece(self._sentencepeice_input_path, 20,
                         sentencepeice_model_prefix)
    self._sentencepeice_model_path = '{}.model'.format(
        sentencepeice_model_prefix)

  @parameterized.named_parameters(
      ('train_static', True, True, 100, (2, 35)),
      ('train_non_static', True, False, 100, (12, 7)),
      ('non_train_static', False, True, 3, (3, 35)),
      ('non_train_non_static', False, False, 50, (2, 7)),)
  def test_load_dataset(
      self, is_training, static_batch, batch_size, expected_shape):
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=self._record_train_input_path
        if is_training else self._record_test_input_path,
        max_seq_length=35,
        global_batch_size=batch_size,
        is_training=is_training,
        static_batch=static_batch,
        src_lang='en',
        tgt_lang='reverse_en',
        sentencepiece_model_path=self._sentencepeice_model_path)
    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
    examples = next(iter(dataset))
    inputs, targets = examples['inputs'], examples['targets']
    self.assertEqual(inputs.shape, expected_shape)
    self.assertEqual(targets.shape, expected_shape)

  def test_load_dataset_raise_invalid_window(self):
    batch_tokens_size = 10  # this is too small to form buckets.
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=self._record_train_input_path,
        max_seq_length=100,
        global_batch_size=batch_tokens_size,
        is_training=True,
        static_batch=False,
        src_lang='en',
        tgt_lang='reverse_en',
        sentencepiece_model_path=self._sentencepeice_model_path)
    with self.assertRaisesRegex(
        ValueError, 'The token budget, global batch size, is too small.*'):
      _ = wmt_dataloader.WMTDataLoader(data_config).load()


if __name__ == '__main__':
  tf.test.main()