Pradeep Kumar commited on
Commit
c130734
·
verified ·
1 Parent(s): b64b72d

Upload 10 files

Browse files
export_tfhub.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.
16
+
17
+ This tool creates preprocessor and encoder SavedModels suitable for uploading
18
+ to https://tfhub.dev that implement the preprocessor and encoder APIs defined
19
+ at https://www.tensorflow.org/hub/common_saved_model_apis/text.
20
+
21
+ For a full usage guide, see
22
+ https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md
23
+
24
+ Minimal usage examples:
25
+
26
+ 1) Exporting an Encoder from checkpoint and config.
27
+
28
+ ```
29
+ export_tfhub \
30
+ --encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
31
+ --model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
32
+ --vocab_file=${BERT_DIR:?}/vocab.txt \
33
+ --export_type=model \
34
+ --export_path=/tmp/bert_model
35
+ ```
36
+
37
+ An --encoder_config_file can specify encoder types other than BERT.
38
+ For BERT, a --bert_config_file in the legacy JSON format can be passed instead.
39
+
40
+ Flag --vocab_file (and flag --do_lower_case, whose default value is guessed
41
+ from the vocab_file path) capture how BertTokenizer was used in pre-training.
42
+ Use flag --sp_model_file instead if SentencepieceTokenizer was used.
43
+
44
+ Changing --export_type to model_with_mlm additionally creates an `.mlm`
45
+ subobject on the exported SavedModel that can be called to produce
46
+ the logits of the Masked Language Model task from pretraining.
47
+ The help string for flag --model_checkpoint_path explains the checkpoint
48
+ formats required for each --export_type.
49
+
50
+
51
+ 2) Exporting a preprocessor SavedModel
52
+
53
+ ```
54
+ export_tfhub \
55
+ --vocab_file ${BERT_DIR:?}/vocab.txt \
56
+ --export_type preprocessing --export_path /tmp/bert_preprocessing
57
+ ```
58
+
59
+ Be sure to use flag values that match the encoder and how it has been
60
+ pre-trained (see above for --vocab_file vs --sp_model_file).
61
+
62
+ If your encoder has been trained with text preprocessing for which tfhub.dev
63
+ already has SavedModel, you could guide your users to reuse that one instead
64
+ of exporting and publishing your own.
65
+
66
+ TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
67
+ `--experimental_disable_assert_in_preprocessing`.
68
+ """
69
+
70
+ from absl import app
71
+ from absl import flags
72
+ import gin
73
+
74
+ from official.legacy.bert import configs
75
+ from official.modeling import hyperparams
76
+ from official.nlp.configs import encoders
77
+ from official.nlp.tools import export_tfhub_lib
78
+
79
+ FLAGS = flags.FLAGS
80
+
81
+ flags.DEFINE_enum(
82
+ "export_type", "model",
83
+ ["model", "model_with_mlm", "preprocessing"],
84
+ "The overall type of SavedModel to export. Flags "
85
+ "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
86
+ "control which particular encoder model and preprocessing are exported.")
87
+ flags.DEFINE_string(
88
+ "export_path", None,
89
+ "Directory to which the SavedModel is written.")
90
+ flags.DEFINE_string(
91
+ "encoder_config_file", None,
92
+ "A yaml file representing `encoders.EncoderConfig` to define the encoder "
93
+ "(BERT or other). "
94
+ "Exactly one of --bert_config_file and --encoder_config_file can be set. "
95
+ "Needed for --export_type model and model_with_mlm.")
96
+ flags.DEFINE_string(
97
+ "bert_config_file", None,
98
+ "A JSON file with a legacy BERT configuration to define the BERT encoder. "
99
+ "Exactly one of --bert_config_file and --encoder_config_file can be set. "
100
+ "Needed for --export_type model and model_with_mlm.")
101
+ flags.DEFINE_bool(
102
+ "copy_pooler_dense_to_encoder", False,
103
+ "When the model is trained using `BertPretrainerV2`, the pool layer "
104
+ "of next sentence prediction task exists in `ClassificationHead` passed "
105
+ "to `BertPretrainerV2`. If True, we will copy this pooler's dense layer "
106
+ "to the encoder that is exported by this tool (as in classic BERT). "
107
+ "Using `BertPretrainerV2` and leaving this False exports an untrained "
108
+ "(randomly initialized) pooling layer, which some authors recommend for "
109
+ "subsequent fine-tuning,")
110
+ flags.DEFINE_string(
111
+ "model_checkpoint_path", None,
112
+ "File path to a pre-trained model checkpoint. "
113
+ "For --export_type model, this has to be an object-based (TF2) checkpoint "
114
+ "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
115
+ "for the `encoder` defined by the config file."
116
+ "(Legacy checkpoints with `model=` instead of `encoder=` are also "
117
+ "supported for now.) "
118
+ "For --export_type model_with_mlm, it must be restorable to "
119
+ "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
120
+ "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
121
+ "accepted.)")
122
+ flags.DEFINE_string(
123
+ "vocab_file", None,
124
+ "For encoders trained on BertTokenzier input: "
125
+ "the vocabulary file that the encoder model was trained with. "
126
+ "Exactly one of --vocab_file and --sp_model_file can be set. "
127
+ "Needed for --export_type model, model_with_mlm and preprocessing.")
128
+ flags.DEFINE_string(
129
+ "sp_model_file", None,
130
+ "For encoders trained on SentencepieceTokenzier input: "
131
+ "the SentencePiece .model file that the encoder model was trained with. "
132
+ "Exactly one of --vocab_file and --sp_model_file can be set. "
133
+ "Needed for --export_type model, model_with_mlm and preprocessing.")
134
+ flags.DEFINE_bool(
135
+ "do_lower_case", None,
136
+ "Whether to lowercase before tokenization. "
137
+ "If left as None, and --vocab_file is set, do_lower_case will be enabled "
138
+ "if 'uncased' appears in the name of --vocab_file. "
139
+ "If left as None, and --sp_model_file set, do_lower_case defaults to true. "
140
+ "Needed for --export_type model, model_with_mlm and preprocessing.")
141
+ flags.DEFINE_integer(
142
+ "default_seq_length", 128,
143
+ "The sequence length of preprocessing results from "
144
+ "top-level preprocess method. This is also the default "
145
+ "sequence length for the bert_pack_inputs subobject."
146
+ "Needed for --export_type preprocessing.")
147
+ flags.DEFINE_bool(
148
+ "tokenize_with_offsets", False, # TODO(b/181866850)
149
+ "Whether to export a .tokenize_with_offsets subobject for "
150
+ "--export_type preprocessing.")
151
+ flags.DEFINE_multi_string(
152
+ "gin_file", default=None,
153
+ help="List of paths to the config files.")
154
+ flags.DEFINE_multi_string(
155
+ "gin_params", default=None,
156
+ help="List of Gin bindings.")
157
+ flags.DEFINE_bool( # TODO(b/175369555): Remove this flag and its use.
158
+ "experimental_disable_assert_in_preprocessing", False,
159
+ "Export a preprocessing model without tf.Assert ops. "
160
+ "Usually, that would be a bad idea, except TF2.4 has an issue with "
161
+ "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
162
+ "and omitting the Assert ops lets SavedModels avoid the issue.")
163
+
164
+
165
+ def main(argv):
166
+ if len(argv) > 1:
167
+ raise app.UsageError("Too many command-line arguments.")
168
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
169
+
170
+ if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
171
+ raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
172
+ "can be specified, but got %s and %s." %
173
+ (FLAGS.vocab_file, FLAGS.sp_model_file))
174
+ do_lower_case = export_tfhub_lib.get_do_lower_case(
175
+ FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)
176
+
177
+ if FLAGS.export_type in ("model", "model_with_mlm"):
178
+ if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
179
+ raise ValueError("Exactly one of `bert_config_file` and "
180
+ "`encoder_config_file` can be specified, but got "
181
+ "%s and %s." %
182
+ (FLAGS.bert_config_file, FLAGS.encoder_config_file))
183
+ if FLAGS.bert_config_file:
184
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
185
+ encoder_config = None
186
+ else:
187
+ bert_config = None
188
+ encoder_config = encoders.EncoderConfig()
189
+ encoder_config = hyperparams.override_params_dict(
190
+ encoder_config, FLAGS.encoder_config_file, is_strict=True)
191
+ export_tfhub_lib.export_model(
192
+ FLAGS.export_path,
193
+ bert_config=bert_config,
194
+ encoder_config=encoder_config,
195
+ model_checkpoint_path=FLAGS.model_checkpoint_path,
196
+ vocab_file=FLAGS.vocab_file,
197
+ sp_model_file=FLAGS.sp_model_file,
198
+ do_lower_case=do_lower_case,
199
+ with_mlm=FLAGS.export_type == "model_with_mlm",
200
+ copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)
201
+
202
+ elif FLAGS.export_type == "preprocessing":
203
+ export_tfhub_lib.export_preprocessing(
204
+ FLAGS.export_path,
205
+ vocab_file=FLAGS.vocab_file,
206
+ sp_model_file=FLAGS.sp_model_file,
207
+ do_lower_case=do_lower_case,
208
+ default_seq_length=FLAGS.default_seq_length,
209
+ tokenize_with_offsets=FLAGS.tokenize_with_offsets,
210
+ experimental_disable_assert=
211
+ FLAGS.experimental_disable_assert_in_preprocessing)
212
+
213
+ else:
214
+ raise app.UsageError(
215
+ "Unknown value '%s' for flag --export_type" % FLAGS.export_type)
216
+
217
+
218
+ if __name__ == "__main__":
219
+ app.run(main)
export_tfhub_lib.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Library of components of export_tfhub.py. See docstring there for more."""
16
+
17
+ import contextlib
18
+ import hashlib
19
+ import os
20
+ import tempfile
21
+
22
+ from typing import Optional, Text, Tuple
23
+
24
+ # Import libraries
25
+ from absl import logging
26
+ import tensorflow as tf, tf_keras
27
+ # pylint: disable=g-direct-tensorflow-import TODO(b/175369555): Remove these.
28
+ from tensorflow.core.protobuf import saved_model_pb2
29
+ from tensorflow.python.ops import control_flow_assert
30
+ # pylint: enable=g-direct-tensorflow-import
31
+ from official.legacy.bert import configs
32
+ from official.modeling import tf_utils
33
+ from official.nlp.configs import encoders
34
+ from official.nlp.modeling import layers
35
+ from official.nlp.modeling import models
36
+ from official.nlp.modeling import networks
37
+
38
+
39
+ def get_bert_encoder(bert_config):
40
+ """Returns a BertEncoder with dict outputs."""
41
+ bert_encoder = networks.BertEncoder(
42
+ vocab_size=bert_config.vocab_size,
43
+ hidden_size=bert_config.hidden_size,
44
+ num_layers=bert_config.num_hidden_layers,
45
+ num_attention_heads=bert_config.num_attention_heads,
46
+ intermediate_size=bert_config.intermediate_size,
47
+ activation=tf_utils.get_activation(bert_config.hidden_act),
48
+ dropout_rate=bert_config.hidden_dropout_prob,
49
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
50
+ max_sequence_length=bert_config.max_position_embeddings,
51
+ type_vocab_size=bert_config.type_vocab_size,
52
+ initializer=tf_keras.initializers.TruncatedNormal(
53
+ stddev=bert_config.initializer_range),
54
+ embedding_width=bert_config.embedding_size,
55
+ dict_outputs=True)
56
+
57
+ return bert_encoder
58
+
59
+
60
+ def get_do_lower_case(do_lower_case, vocab_file=None, sp_model_file=None):
61
+ """Returns do_lower_case, replacing None by a guess from vocab file name."""
62
+ if do_lower_case is not None:
63
+ return do_lower_case
64
+ elif vocab_file:
65
+ do_lower_case = "uncased" in vocab_file
66
+ logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
67
+ do_lower_case, vocab_file)
68
+ return do_lower_case
69
+ elif sp_model_file:
70
+ do_lower_case = True # All public ALBERTs (as of Oct 2020) do it.
71
+ logging.info("Defaulting to do_lower_case=%s for Sentencepiece tokenizer",
72
+ do_lower_case)
73
+ return do_lower_case
74
+ else:
75
+ raise ValueError("Must set vocab_file or sp_model_file.")
76
+
77
+
78
+ def _create_model(
79
+ *,
80
+ bert_config: Optional[configs.BertConfig] = None,
81
+ encoder_config: Optional[encoders.EncoderConfig] = None,
82
+ with_mlm: bool,
83
+ ) -> Tuple[tf_keras.Model, tf_keras.Model]:
84
+ """Creates the model to export and the model to restore the checkpoint.
85
+
86
+ Args:
87
+ bert_config: A legacy `BertConfig` to create a `BertEncoder` object. Exactly
88
+ one of encoder_config and bert_config must be set.
89
+ encoder_config: An `EncoderConfig` to create an encoder of the configured
90
+ type (`BertEncoder` or other).
91
+ with_mlm: A bool to control the second component of the result. If True,
92
+ will create a `BertPretrainerV2` object; otherwise, will create a
93
+ `BertEncoder` object.
94
+
95
+ Returns:
96
+ A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
97
+ object or `BertEncoder` object depending on the value of `with_mlm`
98
+ argument, which contains the first model and will be used for restoring
99
+ weights from the checkpoint.
100
+ """
101
+ if (bert_config is not None) == (encoder_config is not None):
102
+ raise ValueError("Exactly one of `bert_config` and `encoder_config` "
103
+ "can be specified, but got %s and %s" %
104
+ (bert_config, encoder_config))
105
+
106
+ if bert_config is not None:
107
+ encoder = get_bert_encoder(bert_config)
108
+ else:
109
+ encoder = encoders.build_encoder(encoder_config)
110
+
111
+ # Convert from list of named inputs to dict of inputs keyed by name.
112
+ # Only the latter accepts a dict of inputs after restoring from SavedModel.
113
+ if isinstance(encoder.inputs, list) or isinstance(encoder.inputs, tuple):
114
+ encoder_inputs_dict = {x.name: x for x in encoder.inputs}
115
+ else:
116
+ # encoder.inputs by default is dict for BertEncoderV2.
117
+ encoder_inputs_dict = encoder.inputs
118
+ encoder_output_dict = encoder(encoder_inputs_dict)
119
+ # For interchangeability with other text representations,
120
+ # add "default" as an alias for BERT's whole-input reptesentations.
121
+ encoder_output_dict["default"] = encoder_output_dict["pooled_output"]
122
+ core_model = tf_keras.Model(
123
+ inputs=encoder_inputs_dict, outputs=encoder_output_dict)
124
+
125
+ if with_mlm:
126
+ if bert_config is not None:
127
+ hidden_act = bert_config.hidden_act
128
+ else:
129
+ assert encoder_config is not None
130
+ hidden_act = encoder_config.get().hidden_activation
131
+
132
+ pretrainer = models.BertPretrainerV2(
133
+ encoder_network=encoder,
134
+ mlm_activation=tf_utils.get_activation(hidden_act))
135
+
136
+ if isinstance(pretrainer.inputs, dict):
137
+ pretrainer_inputs_dict = pretrainer.inputs
138
+ else:
139
+ pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
140
+ pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
141
+ mlm_model = tf_keras.Model(
142
+ inputs=pretrainer_inputs_dict, outputs=pretrainer_output_dict)
143
+ # Set `_auto_track_sub_layers` to False, so that the additional weights
144
+ # from `mlm` sub-object will not be included in the core model.
145
+ # TODO(b/169210253): Use a public API when available.
146
+ core_model._auto_track_sub_layers = False # pylint: disable=protected-access
147
+ core_model.mlm = mlm_model
148
+ return core_model, pretrainer
149
+ else:
150
+ return core_model, encoder
151
+
152
+
153
+ def export_model(export_path: Text,
154
+ *,
155
+ bert_config: Optional[configs.BertConfig] = None,
156
+ encoder_config: Optional[encoders.EncoderConfig] = None,
157
+ model_checkpoint_path: Text,
158
+ with_mlm: bool,
159
+ copy_pooler_dense_to_encoder: bool = False,
160
+ vocab_file: Optional[Text] = None,
161
+ sp_model_file: Optional[Text] = None,
162
+ do_lower_case: Optional[bool] = None) -> None:
163
+ """Exports an Encoder as SavedModel after restoring pre-trained weights.
164
+
165
+ The exported SavedModel implements a superset of the Encoder API for
166
+ Text embeddings with Transformer Encoders described at
167
+ https://www.tensorflow.org/hub/common_saved_model_apis/text.
168
+
169
+ In particular, the exported SavedModel can be used in the following way:
170
+
171
+ ```
172
+ # Calls default interface (encoder only).
173
+
174
+ encoder = hub.load(...)
175
+ encoder_inputs = dict(
176
+ input_word_ids=..., # Shape [batch, seq_length], dtype=int32
177
+ input_mask=..., # Shape [batch, seq_length], dtype=int32
178
+ input_type_ids=..., # Shape [batch, seq_length], dtype=int32
179
+ )
180
+ encoder_outputs = encoder(encoder_inputs)
181
+ assert encoder_outputs.keys() == {
182
+ "pooled_output", # Shape [batch_size, width], dtype=float32
183
+ "default", # Alias for "pooled_output" (aligns with other models).
184
+ "sequence_output" # Shape [batch_size, seq_length, width], dtype=float32
185
+ "encoder_outputs", # List of Tensors with outputs of all transformer layers.
186
+ }
187
+ ```
188
+
189
+ If `with_mlm` is True, the exported SavedModel can also be called in the
190
+ following way:
191
+
192
+ ```
193
+ # Calls expanded interface that includes logits of the Masked Language Model.
194
+ mlm_inputs = dict(
195
+ input_word_ids=..., # Shape [batch, seq_length], dtype=int32
196
+ input_mask=..., # Shape [batch, seq_length], dtype=int32
197
+ input_type_ids=..., # Shape [batch, seq_length], dtype=int32
198
+ masked_lm_positions=..., # Shape [batch, num_predictions], dtype=int32
199
+ )
200
+ mlm_outputs = encoder.mlm(mlm_inputs)
201
+ assert mlm_outputs.keys() == {
202
+ "pooled_output", # Shape [batch, width], dtype=float32
203
+ "sequence_output", # Shape [batch, seq_length, width], dtype=float32
204
+ "encoder_outputs", # List of Tensors with outputs of all transformer layers.
205
+ "mlm_logits" # Shape [batch, num_predictions, vocab_size], dtype=float32
206
+ }
207
+ ```
208
+
209
+ Args:
210
+ export_path: The SavedModel output directory.
211
+ bert_config: An optional `configs.BertConfig` object. Note: exactly one of
212
+ `bert_config` and following `encoder_config` must be specified.
213
+ encoder_config: An optional `encoders.EncoderConfig` object.
214
+ model_checkpoint_path: The path to the checkpoint.
215
+ with_mlm: Whether to export the additional mlm sub-object.
216
+ copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer used
217
+ in the next sentence prediction task to the encoder.
218
+ vocab_file: The path to the wordpiece vocab file, or None.
219
+ sp_model_file: The path to the sentencepiece model file, or None. Exactly
220
+ one of vocab_file and sp_model_file must be set.
221
+ do_lower_case: Whether to lower-case text before tokenization.
222
+ """
223
+ if with_mlm:
224
+ core_model, pretrainer = _create_model(
225
+ bert_config=bert_config,
226
+ encoder_config=encoder_config,
227
+ with_mlm=with_mlm)
228
+ encoder = pretrainer.encoder_network
229
+ # It supports both the new pretrainer checkpoint produced by TF-NLP and
230
+ # the checkpoint converted from TF1 (original BERT, SmallBERTs).
231
+ checkpoint_items = pretrainer.checkpoint_items
232
+ checkpoint = tf.train.Checkpoint(**checkpoint_items)
233
+ else:
234
+ core_model, encoder = _create_model(
235
+ bert_config=bert_config,
236
+ encoder_config=encoder_config,
237
+ with_mlm=with_mlm)
238
+ checkpoint = tf.train.Checkpoint(
239
+ model=encoder, # Legacy checkpoints.
240
+ encoder=encoder)
241
+ checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
242
+
243
+ if copy_pooler_dense_to_encoder:
244
+ logging.info("Copy pooler's dense layer to the encoder.")
245
+ pooler_checkpoint = tf.train.Checkpoint(
246
+ **{"next_sentence.pooler_dense": encoder.pooler_layer})
247
+ pooler_checkpoint.restore(
248
+ model_checkpoint_path).assert_existing_objects_matched()
249
+
250
+ # Before SavedModels for preprocessing appeared in Oct 2020, the encoders
251
+ # provided this information to let users do preprocessing themselves.
252
+ # We keep doing that for now. It helps users to upgrade incrementally.
253
+ # Moreover, it offers an escape hatch for advanced users who want the
254
+ # full vocab, not the high-level operations from the preprocessing model.
255
+ if vocab_file:
256
+ core_model.vocab_file = tf.saved_model.Asset(vocab_file)
257
+ if do_lower_case is None:
258
+ raise ValueError("Must pass do_lower_case if passing vocab_file.")
259
+ core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
260
+ elif sp_model_file:
261
+ # This was used by ALBERT, with implied values of do_lower_case=True
262
+ # and strip_diacritics=True.
263
+ core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
264
+ else:
265
+ raise ValueError("Must set vocab_file or sp_model_file")
266
+ core_model.save(export_path, include_optimizer=False, save_format="tf")
267
+
268
+
269
+ class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
270
+ """Wraps a BertPackInputs layer for export to SavedModel.
271
+
272
+ The wrapper object is suitable for use with `tf.saved_model.save()` and
273
+ `.load()`. The wrapper object is callable with inputs and outputs like the
274
+ BertPackInputs layer, but differs from saving an unwrapped Keras object:
275
+
276
+ - The inputs can be a list of 1 or 2 RaggedTensors of dtype int32 and
277
+ ragged rank 1 or 2. (In Keras, saving to a tf.function in a SavedModel
278
+ would fix the number of RaggedTensors and their ragged rank.)
279
+ - The call accepts an optional keyword argument `seq_length=` to override
280
+ the layer's .seq_length hyperparameter. (In Keras, a hyperparameter
281
+ could not be changed after saving to a tf.function in a SavedModel.)
282
+ """
283
+
284
+ def __init__(self, bert_pack_inputs: layers.BertPackInputs):
285
+ super().__init__()
286
+
287
+ # Preserve the layer's configured seq_length as a default but make it
288
+ # overridable. Having this dynamically determined default argument
289
+ # requires self.__call__ to be defined in this indirect way.
290
+ default_seq_length = bert_pack_inputs.seq_length
291
+
292
+ @tf.function(autograph=False)
293
+ def call(inputs, seq_length=default_seq_length):
294
+ return layers.BertPackInputs.bert_pack_inputs(
295
+ inputs,
296
+ seq_length=seq_length,
297
+ start_of_sequence_id=bert_pack_inputs.start_of_sequence_id,
298
+ end_of_segment_id=bert_pack_inputs.end_of_segment_id,
299
+ padding_id=bert_pack_inputs.padding_id)
300
+
301
+ self.__call__ = call
302
+
303
+ for ragged_rank in range(1, 3):
304
+ for num_segments in range(1, 3):
305
+ _ = self.__call__.get_concrete_function([
306
+ tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
307
+ for _ in range(num_segments)
308
+ ],
309
+ seq_length=tf.TensorSpec(
310
+ [], tf.int32))
311
+
312
+
313
+ def create_preprocessing(*,
314
+ vocab_file: Optional[str] = None,
315
+ sp_model_file: Optional[str] = None,
316
+ do_lower_case: bool,
317
+ tokenize_with_offsets: bool,
318
+ default_seq_length: int) -> tf_keras.Model:
319
+ """Returns a preprocessing Model for given tokenization parameters.
320
+
321
+ This function builds a Keras Model with attached subobjects suitable for
322
+ saving to a SavedModel. The resulting SavedModel implements the Preprocessor
323
+ API for Text embeddings with Transformer Encoders described at
324
+ https://www.tensorflow.org/hub/common_saved_model_apis/text.
325
+
326
+ Args:
327
+ vocab_file: The path to the wordpiece vocab file, or None.
328
+ sp_model_file: The path to the sentencepiece model file, or None. Exactly
329
+ one of vocab_file and sp_model_file must be set. This determines the type
330
+ of tokenzer that is used.
331
+ do_lower_case: Whether to do lower case.
332
+ tokenize_with_offsets: Whether to include the .tokenize_with_offsets
333
+ subobject.
334
+ default_seq_length: The sequence length of preprocessing results from root
335
+ callable. This is also the default sequence length for the
336
+ bert_pack_inputs subobject.
337
+
338
+ Returns:
339
+ A tf_keras.Model object with several attached subobjects, suitable for
340
+ saving as a preprocessing SavedModel.
341
+ """
342
+ # Select tokenizer.
343
+ if bool(vocab_file) == bool(sp_model_file):
344
+ raise ValueError("Must set exactly one of vocab_file, sp_model_file")
345
+ if vocab_file:
346
+ tokenize = layers.BertTokenizer(
347
+ vocab_file=vocab_file,
348
+ lower_case=do_lower_case,
349
+ tokenize_with_offsets=tokenize_with_offsets)
350
+ else:
351
+ tokenize = layers.SentencepieceTokenizer(
352
+ model_file_path=sp_model_file,
353
+ lower_case=do_lower_case,
354
+ strip_diacritics=True, # Strip diacritics to follow ALBERT model.
355
+ tokenize_with_offsets=tokenize_with_offsets)
356
+
357
+ # The root object of the preprocessing model can be called to do
358
+ # one-shot preprocessing for users with single-sentence inputs.
359
+ sentences = tf_keras.layers.Input(shape=(), dtype=tf.string, name="sentences")
360
+ if tokenize_with_offsets:
361
+ tokens, start_offsets, limit_offsets = tokenize(sentences)
362
+ else:
363
+ tokens = tokenize(sentences)
364
+ pack = layers.BertPackInputs(
365
+ seq_length=default_seq_length,
366
+ special_tokens_dict=tokenize.get_special_tokens_dict())
367
+ model_inputs = pack(tokens)
368
+ preprocessing = tf_keras.Model(sentences, model_inputs)
369
+
370
+ # Individual steps of preprocessing are made available as named subobjects
371
+ # to enable more general preprocessing. For saving, they need to be Models
372
+ # in their own right.
373
+ preprocessing.tokenize = tf_keras.Model(sentences, tokens)
374
+ # Provide an equivalent to tokenize.get_special_tokens_dict().
375
+ preprocessing.tokenize.get_special_tokens_dict = tf.train.Checkpoint()
376
+ preprocessing.tokenize.get_special_tokens_dict.__call__ = tf.function(
377
+ lambda: tokenize.get_special_tokens_dict(), # pylint: disable=[unnecessary-lambda]
378
+ input_signature=[])
379
+ if tokenize_with_offsets:
380
+ preprocessing.tokenize_with_offsets = tf_keras.Model(
381
+ sentences, [tokens, start_offsets, limit_offsets])
382
+ preprocessing.tokenize_with_offsets.get_special_tokens_dict = (
383
+ preprocessing.tokenize.get_special_tokens_dict)
384
+ # Conceptually, this should be
385
+ # preprocessing.bert_pack_inputs = tf_keras.Model(tokens, model_inputs)
386
+ # but technicalities require us to use a wrapper (see comments there).
387
+ # In particular, seq_length can be overridden when calling this.
388
+ preprocessing.bert_pack_inputs = BertPackInputsSavedModelWrapper(pack)
389
+
390
+ return preprocessing
391
+
392
+
393
+ def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]:
394
+ """Returns new path with same basename and hash of original path."""
395
+ if file_path is None:
396
+ return None
397
+ olddir, filename = os.path.split(file_path)
398
+ hasher = hashlib.sha1()
399
+ hasher.update(olddir.encode("utf-8"))
400
+ target_dir = os.path.join(tmpdir, hasher.hexdigest())
401
+ target_file = os.path.join(target_dir, filename)
402
+ tf.io.gfile.mkdir(target_dir)
403
+ tf.io.gfile.copy(file_path, target_file)
404
+ return target_file
405
+
406
+
407
+ def export_preprocessing(export_path: Text,
408
+ *,
409
+ vocab_file: Optional[Text] = None,
410
+ sp_model_file: Optional[Text] = None,
411
+ do_lower_case: bool,
412
+ tokenize_with_offsets: bool,
413
+ default_seq_length: int,
414
+ experimental_disable_assert: bool = False) -> None:
415
+ """Exports preprocessing to a SavedModel for TF Hub."""
416
+ with tempfile.TemporaryDirectory() as tmpdir:
417
+ # TODO(b/175369555): Remove experimental_disable_assert and its use.
418
+ with _maybe_disable_assert(experimental_disable_assert):
419
+ preprocessing = create_preprocessing(
420
+ vocab_file=_move_to_tmpdir(vocab_file, tmpdir),
421
+ sp_model_file=_move_to_tmpdir(sp_model_file, tmpdir),
422
+ do_lower_case=do_lower_case,
423
+ tokenize_with_offsets=tokenize_with_offsets,
424
+ default_seq_length=default_seq_length)
425
+ preprocessing.save(export_path, include_optimizer=False, save_format="tf")
426
+ if experimental_disable_assert:
427
+ _check_no_assert(export_path)
428
+ # It helps the unit test to prevent stray copies of the vocab file.
429
+ if tf.io.gfile.exists(tmpdir):
430
+ raise IOError("Failed to clean up TemporaryDirectory")
431
+
432
+
433
+ # TODO(b/175369555): Remove all workarounds for this bug of TensorFlow 2.4
434
+ # when this bug is no longer a concern for publishing new models.
435
+ # TensorFlow 2.4 has a placement issue with Assert ops in tf.functions called
436
+ # from Dataset.map() on a TPU worker. They end up on the TPU coordinator,
437
+ # and invoking them from the TPU worker is either inefficient (when possible)
438
+ # or impossible (notably when using "headless" TPU workers on Cloud that do not
439
+ # have a channel to the coordinator). The bug has been fixed in time for TF 2.5.
440
+ # To work around this, the following code avoids Assert ops in the exported
441
+ # SavedModels. It monkey-patches calls to tf.Assert from inside TensorFlow and
442
+ # replaces them by a no-op while building the exported model. This is fragile,
443
+ # so _check_no_assert() validates the result. The resulting model should be fine
444
+ # to read on future versions of TF, even if this workaround at export time
445
+ # may break eventually. (Failing unit tests will tell.)
446
+
447
+
448
+ def _dont_assert(condition, data, summarize=None, name="Assert"):
449
+ """The no-op version of tf.Assert installed by _maybe_disable_assert."""
450
+ del condition, data, summarize # Unused.
451
+ if tf.executing_eagerly():
452
+ return
453
+ with tf.name_scope(name):
454
+ return tf.no_op(name="dont_assert")
455
+
456
+
457
+ @contextlib.contextmanager
458
+ def _maybe_disable_assert(disable_assert):
459
+ """Scoped monkey patch of control_flow_assert.Assert to a no-op."""
460
+ if not disable_assert:
461
+ yield
462
+ return
463
+
464
+ original_assert = control_flow_assert.Assert
465
+ control_flow_assert.Assert = _dont_assert
466
+ yield
467
+ control_flow_assert.Assert = original_assert
468
+
469
+
470
+ def _check_no_assert(saved_model_path):
471
+ """Raises AssertionError if SavedModel contains Assert ops."""
472
+ saved_model_filename = os.path.join(saved_model_path, "saved_model.pb")
473
+ with tf.io.gfile.GFile(saved_model_filename, "rb") as f:
474
+ saved_model = saved_model_pb2.SavedModel.FromString(f.read())
475
+
476
+ assert_nodes = []
477
+ graph_def = saved_model.meta_graphs[0].graph_def
478
+ assert_nodes += [
479
+ "node '{}' in global graph".format(n.name)
480
+ for n in graph_def.node
481
+ if n.op == "Assert"
482
+ ]
483
+ for fdef in graph_def.library.function:
484
+ assert_nodes += [
485
+ "node '{}' in function '{}'".format(n.name, fdef.signature.name)
486
+ for n in fdef.node_def
487
+ if n.op == "Assert"
488
+ ]
489
+ if assert_nodes:
490
+ raise AssertionError(
491
+ "Internal tool error: "
492
+ "failed to suppress {} Assert ops in SavedModel:\n{}".format(
493
+ len(assert_nodes), "\n".join(assert_nodes[:10])))
export_tfhub_lib_test.py ADDED
@@ -0,0 +1,1080 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests export_tfhub_lib."""
16
+
17
+ import os
18
+ import tempfile
19
+
20
+ from absl.testing import parameterized
21
+ import numpy as np
22
+ import tensorflow as tf, tf_keras
23
+ from tensorflow import estimator as tf_estimator
24
+ import tensorflow_hub as hub
25
+ import tensorflow_text as text
26
+
27
+ from sentencepiece import SentencePieceTrainer
28
+ from official.legacy.bert import configs
29
+ from official.modeling import tf_utils
30
+ from official.nlp.configs import encoders
31
+ from official.nlp.modeling import layers
32
+ from official.nlp.modeling import models
33
+ from official.nlp.tools import export_tfhub_lib
34
+
35
+
36
+ def _get_bert_config_or_encoder_config(use_bert_config,
37
+ hidden_size,
38
+ num_hidden_layers,
39
+ encoder_type="albert",
40
+ vocab_size=100):
41
+ """Generates config args for export_tfhub_lib._create_model().
42
+
43
+ Args:
44
+ use_bert_config: bool. If True, returns legacy BertConfig.
45
+ hidden_size: int.
46
+ num_hidden_layers: int.
47
+ encoder_type: str. Can be ['albert', 'bert', 'bert_v2']. If use_bert_config
48
+ == True, then model_type is not used.
49
+ vocab_size: int.
50
+
51
+ Returns:
52
+ bert_config, encoder_config. Only one is not None. If
53
+ `use_bert_config` == True, the first config is valid. Otherwise
54
+ `bert_config` == None.
55
+ """
56
+ if use_bert_config:
57
+ bert_config = configs.BertConfig(
58
+ vocab_size=vocab_size,
59
+ hidden_size=hidden_size,
60
+ intermediate_size=32,
61
+ max_position_embeddings=128,
62
+ num_attention_heads=2,
63
+ num_hidden_layers=num_hidden_layers)
64
+ encoder_config = None
65
+ else:
66
+ bert_config = None
67
+ if encoder_type == "albert":
68
+ encoder_config = encoders.EncoderConfig(
69
+ type="albert",
70
+ albert=encoders.AlbertEncoderConfig(
71
+ vocab_size=vocab_size,
72
+ embedding_width=16,
73
+ hidden_size=hidden_size,
74
+ intermediate_size=32,
75
+ max_position_embeddings=128,
76
+ num_attention_heads=2,
77
+ num_layers=num_hidden_layers,
78
+ dropout_rate=0.1))
79
+ else:
80
+ # encoder_type can be 'bert' or 'bert_v2'.
81
+ model_config = encoders.BertEncoderConfig(
82
+ vocab_size=vocab_size,
83
+ embedding_size=16,
84
+ hidden_size=hidden_size,
85
+ intermediate_size=32,
86
+ max_position_embeddings=128,
87
+ num_attention_heads=2,
88
+ num_layers=num_hidden_layers,
89
+ dropout_rate=0.1)
90
+ kwargs = {"type": encoder_type, encoder_type: model_config}
91
+ encoder_config = encoders.EncoderConfig(**kwargs)
92
+
93
+ return bert_config, encoder_config
94
+
95
+
96
+ def _get_vocab_or_sp_model_dummy(temp_dir, use_sp_model):
97
+ """Returns tokenizer asset args for export_tfhub_lib.export_model()."""
98
+ dummy_file = os.path.join(temp_dir, "dummy_file.txt")
99
+ with tf.io.gfile.GFile(dummy_file, "w") as f:
100
+ f.write("dummy content")
101
+ if use_sp_model:
102
+ vocab_file, sp_model_file = None, dummy_file
103
+ else:
104
+ vocab_file, sp_model_file = dummy_file, None
105
+ return vocab_file, sp_model_file
106
+
107
+
108
+ def _read_asset(asset: tf.saved_model.Asset):
109
+ return tf.io.gfile.GFile(asset.asset_path.numpy()).read()
110
+
111
+
112
+ def _find_lambda_layers(layer):
113
+ """Returns list of all Lambda layers in a Keras model."""
114
+ if isinstance(layer, tf_keras.layers.Lambda):
115
+ return [layer]
116
+ elif hasattr(layer, "layers"): # It's nested, like a Model.
117
+ result = []
118
+ for l in layer.layers:
119
+ result += _find_lambda_layers(l)
120
+ return result
121
+ else:
122
+ return []
123
+
124
+
125
+ class ExportModelTest(tf.test.TestCase, parameterized.TestCase):
126
+ """Tests exporting a Transformer Encoder model as a SavedModel.
127
+
128
+ This covers export from an Encoder checkpoint to a SavedModel without
129
+ the .mlm subobject. This is no longer preferred, but still useful
130
+ for models like Electra that are trained without the MLM task.
131
+
132
+ The export code is generic. This test focuses on two main cases
133
+ (the most important ones in practice when this was written in 2020):
134
+ - BERT built from a legacy BertConfig, for use with BertTokenizer.
135
+ - ALBERT built from an EncoderConfig (as a representative of all other
136
+ choices beyond BERT, for use with SentencepieceTokenizer (the one
137
+ alternative to BertTokenizer).
138
+ """
139
+
140
+ @parameterized.named_parameters(
141
+ ("Bert_Legacy", True, None), ("Albert", False, "albert"),
142
+ ("BertEncoder", False, "bert"), ("BertEncoderV2", False, "bert_v2"))
143
+ def test_export_model(self, use_bert, encoder_type):
144
+ # Create the encoder and export it.
145
+ hidden_size = 16
146
+ num_hidden_layers = 1
147
+ bert_config, encoder_config = _get_bert_config_or_encoder_config(
148
+ use_bert,
149
+ hidden_size=hidden_size,
150
+ num_hidden_layers=num_hidden_layers,
151
+ encoder_type=encoder_type)
152
+ bert_model, encoder = export_tfhub_lib._create_model(
153
+ bert_config=bert_config, encoder_config=encoder_config, with_mlm=False)
154
+ self.assertEmpty(
155
+ _find_lambda_layers(bert_model),
156
+ "Lambda layers are non-portable since they serialize Python bytecode.")
157
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
158
+ checkpoint = tf.train.Checkpoint(encoder=encoder)
159
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
160
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
161
+
162
+ vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
163
+ self.get_temp_dir(), use_sp_model=not use_bert)
164
+ export_path = os.path.join(self.get_temp_dir(), "hub")
165
+ export_tfhub_lib.export_model(
166
+ export_path=export_path,
167
+ bert_config=bert_config,
168
+ encoder_config=encoder_config,
169
+ model_checkpoint_path=model_checkpoint_path,
170
+ with_mlm=False,
171
+ vocab_file=vocab_file,
172
+ sp_model_file=sp_model_file,
173
+ do_lower_case=True)
174
+
175
+ # Restore the exported model.
176
+ hub_layer = hub.KerasLayer(export_path, trainable=True)
177
+
178
+ # Check legacy tokenization data.
179
+ if use_bert:
180
+ self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
181
+ self.assertEqual("dummy content",
182
+ _read_asset(hub_layer.resolved_object.vocab_file))
183
+ self.assertFalse(hasattr(hub_layer.resolved_object, "sp_model_file"))
184
+ else:
185
+ self.assertFalse(hasattr(hub_layer.resolved_object, "do_lower_case"))
186
+ self.assertFalse(hasattr(hub_layer.resolved_object, "vocab_file"))
187
+ self.assertEqual("dummy content",
188
+ _read_asset(hub_layer.resolved_object.sp_model_file))
189
+
190
+ # Check restored weights.
191
+ self.assertEqual(
192
+ len(bert_model.trainable_weights), len(hub_layer.trainable_weights))
193
+ for source_weight, hub_weight in zip(bert_model.trainable_weights,
194
+ hub_layer.trainable_weights):
195
+ self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
196
+
197
+ # Check computation.
198
+ seq_length = 10
199
+ dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
200
+ input_dict = dict(
201
+ input_word_ids=dummy_ids,
202
+ input_mask=dummy_ids,
203
+ input_type_ids=dummy_ids)
204
+ hub_output = hub_layer(input_dict)
205
+ source_output = bert_model(input_dict)
206
+ encoder_output = encoder(input_dict)
207
+ self.assertEqual(hub_output["pooled_output"].shape, (2, hidden_size))
208
+ self.assertEqual(hub_output["sequence_output"].shape,
209
+ (2, seq_length, hidden_size))
210
+ self.assertLen(hub_output["encoder_outputs"], num_hidden_layers)
211
+
212
+ for key in ("pooled_output", "sequence_output", "encoder_outputs"):
213
+ self.assertAllClose(source_output[key], hub_output[key])
214
+ self.assertAllClose(source_output[key], encoder_output[key])
215
+
216
+ # The "default" output of BERT as a text representation is pooled_output.
217
+ self.assertAllClose(hub_output["pooled_output"], hub_output["default"])
218
+
219
+ # Test that training=True makes a difference (activates dropout).
220
+ def _dropout_mean_stddev(training, num_runs=20):
221
+ input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
222
+ input_dict = dict(
223
+ input_word_ids=input_ids,
224
+ input_mask=np.ones_like(input_ids),
225
+ input_type_ids=np.zeros_like(input_ids))
226
+ outputs = np.concatenate([
227
+ hub_layer(input_dict, training=training)["pooled_output"]
228
+ for _ in range(num_runs)
229
+ ])
230
+ return np.mean(np.std(outputs, axis=0))
231
+
232
+ self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
233
+ self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
234
+
235
+ # Test propagation of seq_length in shape inference.
236
+ input_word_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
237
+ input_mask = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
238
+ input_type_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
239
+ input_dict = dict(
240
+ input_word_ids=input_word_ids,
241
+ input_mask=input_mask,
242
+ input_type_ids=input_type_ids)
243
+ output_dict = hub_layer(input_dict)
244
+ pooled_output = output_dict["pooled_output"]
245
+ sequence_output = output_dict["sequence_output"]
246
+ encoder_outputs = output_dict["encoder_outputs"]
247
+
248
+ self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
249
+ self.assertEqual(sequence_output.shape.as_list(),
250
+ [None, seq_length, hidden_size])
251
+ self.assertLen(encoder_outputs, num_hidden_layers)
252
+
253
+
254
+ class ExportModelWithMLMTest(tf.test.TestCase, parameterized.TestCase):
255
+ """Tests exporting a Transformer Encoder model as a SavedModel.
256
+
257
+ This covers export from a Pretrainer checkpoint to a SavedModel including
258
+ the .mlm subobject, which is the preferred way since 2020.
259
+
260
+ The export code is generic. This test focuses on two main cases
261
+ (the most important ones in practice when this was written in 2020):
262
+ - BERT built from a legacy BertConfig, for use with BertTokenizer.
263
+ - ALBERT built from an EncoderConfig (as a representative of all other
264
+ choices beyond BERT, for use with SentencepieceTokenizer (the one
265
+ alternative to BertTokenizer).
266
+ """
267
+
268
+ def test_copy_pooler_dense_to_encoder(self):
269
+ encoder_config = encoders.EncoderConfig(
270
+ type="bert",
271
+ bert=encoders.BertEncoderConfig(
272
+ hidden_size=24, intermediate_size=48, num_layers=2))
273
+ cls_heads = [
274
+ layers.ClassificationHead(
275
+ inner_dim=24, num_classes=2, name="next_sentence")
276
+ ]
277
+ encoder = encoders.build_encoder(encoder_config)
278
+ pretrainer = models.BertPretrainerV2(
279
+ encoder_network=encoder,
280
+ classification_heads=cls_heads,
281
+ mlm_activation=tf_utils.get_activation(
282
+ encoder_config.get().hidden_activation))
283
+ # Makes sure the pretrainer variables are created.
284
+ _ = pretrainer(pretrainer.inputs)
285
+ checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
286
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
287
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
288
+
289
+ vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
290
+ self.get_temp_dir(), use_sp_model=True)
291
+ export_path = os.path.join(self.get_temp_dir(), "hub")
292
+ export_tfhub_lib.export_model(
293
+ export_path=export_path,
294
+ encoder_config=encoder_config,
295
+ model_checkpoint_path=tf.train.latest_checkpoint(model_checkpoint_dir),
296
+ with_mlm=True,
297
+ copy_pooler_dense_to_encoder=True,
298
+ vocab_file=vocab_file,
299
+ sp_model_file=sp_model_file,
300
+ do_lower_case=True)
301
+ # Restores a hub KerasLayer.
302
+ hub_layer = hub.KerasLayer(export_path, trainable=True)
303
+ dummy_ids = np.zeros((2, 10), dtype=np.int32)
304
+ input_dict = dict(
305
+ input_word_ids=dummy_ids,
306
+ input_mask=dummy_ids,
307
+ input_type_ids=dummy_ids)
308
+ hub_pooled_output = hub_layer(input_dict)["pooled_output"]
309
+ encoder_outputs = encoder(input_dict)
310
+ # Verify that hub_layer's pooled_output is the same as the output of next
311
+ # sentence prediction's dense layer.
312
+ pretrained_pooled_output = cls_heads[0].dense(
313
+ (encoder_outputs["sequence_output"][:, 0, :]))
314
+ self.assertAllClose(hub_pooled_output, pretrained_pooled_output)
315
+ # But the pooled_output between encoder and hub_layer are not the same.
316
+ encoder_pooled_output = encoder_outputs["pooled_output"]
317
+ self.assertNotAllClose(hub_pooled_output, encoder_pooled_output)
318
+
319
+ @parameterized.named_parameters(
320
+ ("Bert", True),
321
+ ("Albert", False),
322
+ )
323
+ def test_export_model_with_mlm(self, use_bert):
324
+ # Create the encoder and export it.
325
+ hidden_size = 16
326
+ num_hidden_layers = 2
327
+ bert_config, encoder_config = _get_bert_config_or_encoder_config(
328
+ use_bert, hidden_size, num_hidden_layers)
329
+ bert_model, pretrainer = export_tfhub_lib._create_model(
330
+ bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
331
+ self.assertEmpty(
332
+ _find_lambda_layers(bert_model),
333
+ "Lambda layers are non-portable since they serialize Python bytecode.")
334
+ bert_model_with_mlm = bert_model.mlm
335
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
336
+
337
+ checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
338
+
339
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
340
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
341
+
342
+ vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
343
+ self.get_temp_dir(), use_sp_model=not use_bert)
344
+ export_path = os.path.join(self.get_temp_dir(), "hub")
345
+ export_tfhub_lib.export_model(
346
+ export_path=export_path,
347
+ bert_config=bert_config,
348
+ encoder_config=encoder_config,
349
+ model_checkpoint_path=model_checkpoint_path,
350
+ with_mlm=True,
351
+ vocab_file=vocab_file,
352
+ sp_model_file=sp_model_file,
353
+ do_lower_case=True)
354
+
355
+ # Restore the exported model.
356
+ hub_layer = hub.KerasLayer(export_path, trainable=True)
357
+
358
+ # Check legacy tokenization data.
359
+ if use_bert:
360
+ self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
361
+ self.assertEqual("dummy content",
362
+ _read_asset(hub_layer.resolved_object.vocab_file))
363
+ self.assertFalse(hasattr(hub_layer.resolved_object, "sp_model_file"))
364
+ else:
365
+ self.assertFalse(hasattr(hub_layer.resolved_object, "do_lower_case"))
366
+ self.assertFalse(hasattr(hub_layer.resolved_object, "vocab_file"))
367
+ self.assertEqual("dummy content",
368
+ _read_asset(hub_layer.resolved_object.sp_model_file))
369
+
370
+ # Check restored weights.
371
+ # Note that we set `_auto_track_sub_layers` to False when exporting the
372
+ # SavedModel, so hub_layer has the same number of weights as bert_model;
373
+ # otherwise, hub_layer will have extra weights from its `mlm` subobject.
374
+ self.assertEqual(
375
+ len(bert_model.trainable_weights), len(hub_layer.trainable_weights))
376
+ for source_weight, hub_weight in zip(bert_model.trainable_weights,
377
+ hub_layer.trainable_weights):
378
+ self.assertAllClose(source_weight, hub_weight)
379
+
380
+ # Check computation.
381
+ seq_length = 10
382
+ dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
383
+ input_dict = dict(
384
+ input_word_ids=dummy_ids,
385
+ input_mask=dummy_ids,
386
+ input_type_ids=dummy_ids)
387
+ hub_outputs_dict = hub_layer(input_dict)
388
+ source_outputs_dict = bert_model(input_dict)
389
+ encoder_outputs_dict = pretrainer.encoder_network(
390
+ [dummy_ids, dummy_ids, dummy_ids])
391
+ self.assertEqual(hub_outputs_dict["pooled_output"].shape, (2, hidden_size))
392
+ self.assertEqual(hub_outputs_dict["sequence_output"].shape,
393
+ (2, seq_length, hidden_size))
394
+ for output_key in ("pooled_output", "sequence_output", "encoder_outputs"):
395
+ self.assertAllClose(source_outputs_dict[output_key],
396
+ hub_outputs_dict[output_key])
397
+ self.assertAllClose(source_outputs_dict[output_key],
398
+ encoder_outputs_dict[output_key])
399
+
400
+ # The "default" output of BERT as a text representation is pooled_output.
401
+ self.assertAllClose(hub_outputs_dict["pooled_output"],
402
+ hub_outputs_dict["default"])
403
+
404
+ # Test that training=True makes a difference (activates dropout).
405
+ def _dropout_mean_stddev(training, num_runs=20):
406
+ input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
407
+ input_dict = dict(
408
+ input_word_ids=input_ids,
409
+ input_mask=np.ones_like(input_ids),
410
+ input_type_ids=np.zeros_like(input_ids))
411
+ outputs = np.concatenate([
412
+ hub_layer(input_dict, training=training)["pooled_output"]
413
+ for _ in range(num_runs)
414
+ ])
415
+ return np.mean(np.std(outputs, axis=0))
416
+
417
+ self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
418
+ self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
419
+
420
+ # Checks sub-object `mlm`.
421
+ self.assertTrue(hasattr(hub_layer.resolved_object, "mlm"))
422
+
423
+ self.assertLen(hub_layer.resolved_object.mlm.trainable_variables,
424
+ len(bert_model_with_mlm.trainable_weights))
425
+ self.assertLen(hub_layer.resolved_object.mlm.trainable_variables,
426
+ len(pretrainer.trainable_weights))
427
+ for source_weight, hub_weight, pretrainer_weight in zip(
428
+ bert_model_with_mlm.trainable_weights,
429
+ hub_layer.resolved_object.mlm.trainable_variables,
430
+ pretrainer.trainable_weights):
431
+ self.assertAllClose(source_weight, hub_weight)
432
+ self.assertAllClose(source_weight, pretrainer_weight)
433
+
434
+ max_predictions_per_seq = 4
435
+ mlm_positions = np.zeros((2, max_predictions_per_seq), dtype=np.int32)
436
+ input_dict = dict(
437
+ input_word_ids=dummy_ids,
438
+ input_mask=dummy_ids,
439
+ input_type_ids=dummy_ids,
440
+ masked_lm_positions=mlm_positions)
441
+ hub_mlm_outputs_dict = hub_layer.resolved_object.mlm(input_dict)
442
+ source_mlm_outputs_dict = bert_model_with_mlm(input_dict)
443
+ for output_key in ("pooled_output", "sequence_output", "mlm_logits",
444
+ "encoder_outputs"):
445
+ self.assertAllClose(hub_mlm_outputs_dict[output_key],
446
+ source_mlm_outputs_dict[output_key])
447
+
448
+ pretrainer_mlm_logits_output = pretrainer(input_dict)["mlm_logits"]
449
+ self.assertAllClose(hub_mlm_outputs_dict["mlm_logits"],
450
+ pretrainer_mlm_logits_output)
451
+
452
+ # Test that training=True makes a difference (activates dropout).
453
+ def _dropout_mean_stddev_mlm(training, num_runs=20):
454
+ input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
455
+ mlm_position_ids = np.array([[1, 2, 3, 4]], np.int32)
456
+ input_dict = dict(
457
+ input_word_ids=input_ids,
458
+ input_mask=np.ones_like(input_ids),
459
+ input_type_ids=np.zeros_like(input_ids),
460
+ masked_lm_positions=mlm_position_ids)
461
+ outputs = np.concatenate([
462
+ hub_layer.resolved_object.mlm(input_dict,
463
+ training=training)["pooled_output"]
464
+ for _ in range(num_runs)
465
+ ])
466
+ return np.mean(np.std(outputs, axis=0))
467
+
468
+ self.assertLess(_dropout_mean_stddev_mlm(training=False), 1e-6)
469
+ self.assertGreater(_dropout_mean_stddev_mlm(training=True), 1e-3)
470
+
471
+ # Test propagation of seq_length in shape inference.
472
+ input_word_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
473
+ input_mask = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
474
+ input_type_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
475
+ input_dict = dict(
476
+ input_word_ids=input_word_ids,
477
+ input_mask=input_mask,
478
+ input_type_ids=input_type_ids)
479
+ hub_outputs_dict = hub_layer(input_dict)
480
+ self.assertEqual(hub_outputs_dict["pooled_output"].shape.as_list(),
481
+ [None, hidden_size])
482
+ self.assertEqual(hub_outputs_dict["sequence_output"].shape.as_list(),
483
+ [None, seq_length, hidden_size])
484
+
485
+
486
+ _STRING_NOT_TO_LEAK = "private_path_component_"
487
+
488
+
489
+ class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):
490
+
491
+ def _make_vocab_file(self, vocab, filename="vocab.txt", add_mask_token=False):
492
+ """Creates wordpiece vocab file with given words plus special tokens.
493
+
494
+ The tokens of the resulting model are, in this order:
495
+ [PAD], [UNK], [CLS], [SEP], [MASK]*, ...vocab...
496
+ *=if requested by args.
497
+
498
+ This function also accepts wordpieces that start with the ## continuation
499
+ marker, but avoiding those makes this function interchangeable with
500
+ _make_sp_model_file(), up to the extra dimension returned by BertTokenizer.
501
+
502
+ Args:
503
+ vocab: a list of strings with the words or wordpieces to put into the
504
+ model's vocabulary. Do not include special tokens here.
505
+ filename: Optionally, a filename (relative to the temporary directory
506
+ created by this function).
507
+ add_mask_token: an optional bool, whether to include a [MASK] token.
508
+
509
+ Returns:
510
+ The absolute filename of the created vocab file.
511
+ """
512
+ full_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"
513
+ ] + ["[MASK]"] * add_mask_token + vocab
514
+ path = os.path.join(
515
+ tempfile.mkdtemp(
516
+ dir=self.get_temp_dir(), # New subdir each time.
517
+ prefix=_STRING_NOT_TO_LEAK),
518
+ filename)
519
+ with tf.io.gfile.GFile(path, "w") as f:
520
+ f.write("\n".join(full_vocab + [""]))
521
+ return path
522
+
523
+ def _make_sp_model_file(self, vocab, prefix="spm", add_mask_token=False):
524
+ """Creates Sentencepiece word model with given words plus special tokens.
525
+
526
+ The tokens of the resulting model are, in this order:
527
+ <pad>, <unk>, [CLS], [SEP], [MASK]*, ...vocab..., <s>, </s>
528
+ *=if requested by args.
529
+
530
+ The words in the input vocab are plain text, without the whitespace marker.
531
+ That makes this function interchangeable with _make_vocab_file().
532
+
533
+ Args:
534
+ vocab: a list of strings with the words to put into the model's
535
+ vocabulary. Do not include special tokens here.
536
+ prefix: an optional string, to change the filename prefix for the model
537
+ (relative to the temporary directory created by this function).
538
+ add_mask_token: an optional bool, whether to include a [MASK] token.
539
+
540
+ Returns:
541
+ The absolute filename of the created Sentencepiece model file.
542
+ """
543
+ model_prefix = os.path.join(
544
+ tempfile.mkdtemp(dir=self.get_temp_dir()), # New subdir each time.
545
+ prefix)
546
+ input_file = model_prefix + "_train_input.txt"
547
+ # Create input text for training the sp model from the tokens provided.
548
+ # Repeat tokens, the earlier the more, because they are sorted by frequency.
549
+ input_text = []
550
+ for i, token in enumerate(vocab):
551
+ input_text.append(" ".join([token] * (len(vocab) - i)))
552
+ with tf.io.gfile.GFile(input_file, "w") as f:
553
+ f.write("\n".join(input_text + [""]))
554
+ control_symbols = "[CLS],[SEP]"
555
+ full_vocab_size = len(vocab) + 6 # <pad>, <unk>, [CLS], [SEP], <s>, </s>.
556
+ if add_mask_token:
557
+ control_symbols += ",[MASK]"
558
+ full_vocab_size += 1
559
+ flags = dict(
560
+ model_prefix=model_prefix,
561
+ model_type="word",
562
+ input=input_file,
563
+ pad_id=0,
564
+ unk_id=1,
565
+ control_symbols=control_symbols,
566
+ vocab_size=full_vocab_size,
567
+ bos_id=full_vocab_size - 2,
568
+ eos_id=full_vocab_size - 1)
569
+ SentencePieceTrainer.Train(" ".join(
570
+ ["--{}={}".format(k, v) for k, v in flags.items()]))
571
+ return model_prefix + ".model"
572
+
573
+ def _do_export(self,
574
+ vocab,
575
+ do_lower_case,
576
+ default_seq_length=128,
577
+ tokenize_with_offsets=True,
578
+ use_sp_model=False,
579
+ experimental_disable_assert=False,
580
+ add_mask_token=False):
581
+ """Runs SavedModel export and returns the export_path."""
582
+ export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
583
+ vocab_file = sp_model_file = None
584
+ if use_sp_model:
585
+ sp_model_file = self._make_sp_model_file(
586
+ vocab, add_mask_token=add_mask_token)
587
+ else:
588
+ vocab_file = self._make_vocab_file(vocab, add_mask_token=add_mask_token)
589
+ export_tfhub_lib.export_preprocessing(
590
+ export_path,
591
+ vocab_file=vocab_file,
592
+ sp_model_file=sp_model_file,
593
+ do_lower_case=do_lower_case,
594
+ tokenize_with_offsets=tokenize_with_offsets,
595
+ default_seq_length=default_seq_length,
596
+ experimental_disable_assert=experimental_disable_assert)
597
+ # Invalidate the original filename to verify loading from the SavedModel.
598
+ tf.io.gfile.remove(sp_model_file or vocab_file)
599
+ return export_path
600
+
601
+ def test_no_leaks(self):
602
+ """Tests not leaking the path to the original vocab file."""
603
+ path = self._do_export(["d", "ef", "abc", "xy"],
604
+ do_lower_case=True,
605
+ use_sp_model=False)
606
+ with tf.io.gfile.GFile(os.path.join(path, "saved_model.pb"), "rb") as f:
607
+ self.assertFalse( # pylint: disable=g-generic-assert
608
+ _STRING_NOT_TO_LEAK.encode("ascii") in f.read())
609
+
610
+ @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
611
+ def test_exported_callables(self, use_sp_model):
612
+ preprocess = tf.saved_model.load(
613
+ self._do_export(
614
+ ["d", "ef", "abc", "xy"],
615
+ do_lower_case=True,
616
+ # TODO(b/181866850): drop this.
617
+ tokenize_with_offsets=not use_sp_model,
618
+ # TODO(b/175369555): drop this.
619
+ experimental_disable_assert=True,
620
+ use_sp_model=use_sp_model))
621
+
622
+ def fold_dim(rt):
623
+ """Removes the word/subword distinction of BertTokenizer."""
624
+ return rt if use_sp_model else rt.merge_dims(1, 2)
625
+
626
+ # .tokenize()
627
+ inputs = tf.constant(["abc d ef", "ABC D EF d"])
628
+ token_ids = preprocess.tokenize(inputs)
629
+ self.assertAllEqual(
630
+ fold_dim(token_ids), tf.ragged.constant([[6, 4, 5], [6, 4, 5, 4]]))
631
+
632
+ special_tokens_dict = {
633
+ k: v.numpy().item() # Expecting eager Tensor, converting to Python.
634
+ for k, v in preprocess.tokenize.get_special_tokens_dict().items()
635
+ }
636
+ self.assertDictEqual(
637
+ special_tokens_dict,
638
+ dict(
639
+ padding_id=0,
640
+ start_of_sequence_id=2,
641
+ end_of_segment_id=3,
642
+ vocab_size=4 + 6 if use_sp_model else 4 + 4))
643
+
644
+ # .tokenize_with_offsets()
645
+ if use_sp_model:
646
+ # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
647
+ self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
648
+ else:
649
+ token_ids, start_offsets, limit_offsets = (
650
+ preprocess.tokenize_with_offsets(inputs))
651
+ self.assertAllEqual(
652
+ fold_dim(token_ids), tf.ragged.constant([[6, 4, 5], [6, 4, 5, 4]]))
653
+ self.assertAllEqual(
654
+ fold_dim(start_offsets), tf.ragged.constant([[0, 4, 6], [0, 4, 6,
655
+ 9]]))
656
+ self.assertAllEqual(
657
+ fold_dim(limit_offsets), tf.ragged.constant([[3, 5, 8], [3, 5, 8,
658
+ 10]]))
659
+ self.assertIs(preprocess.tokenize.get_special_tokens_dict,
660
+ preprocess.tokenize_with_offsets.get_special_tokens_dict)
661
+
662
+ # Root callable.
663
+ bert_inputs = preprocess(inputs)
664
+ self.assertAllEqual(bert_inputs["input_word_ids"].shape.as_list(), [2, 128])
665
+ self.assertAllEqual(
666
+ bert_inputs["input_word_ids"][:, :10],
667
+ tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
668
+ [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
669
+ self.assertAllEqual(bert_inputs["input_mask"].shape.as_list(), [2, 128])
670
+ self.assertAllEqual(
671
+ bert_inputs["input_mask"][:, :10],
672
+ tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
673
+ [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
674
+ self.assertAllEqual(bert_inputs["input_type_ids"].shape.as_list(), [2, 128])
675
+ self.assertAllEqual(
676
+ bert_inputs["input_type_ids"][:, :10],
677
+ tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
678
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
679
+
680
+ # .bert_pack_inputs()
681
+ inputs_2 = tf.constant(["d xy", "xy abc"])
682
+ token_ids_2 = preprocess.tokenize(inputs_2)
683
+ bert_inputs = preprocess.bert_pack_inputs([token_ids, token_ids_2],
684
+ seq_length=256)
685
+ self.assertAllEqual(bert_inputs["input_word_ids"].shape.as_list(), [2, 256])
686
+ self.assertAllEqual(
687
+ bert_inputs["input_word_ids"][:, :10],
688
+ tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
689
+ [2, 6, 4, 5, 4, 3, 7, 6, 3, 0]]))
690
+ self.assertAllEqual(bert_inputs["input_mask"].shape.as_list(), [2, 256])
691
+ self.assertAllEqual(
692
+ bert_inputs["input_mask"][:, :10],
693
+ tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
694
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]))
695
+ self.assertAllEqual(bert_inputs["input_type_ids"].shape.as_list(), [2, 256])
696
+ self.assertAllEqual(
697
+ bert_inputs["input_type_ids"][:, :10],
698
+ tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
699
+ [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]]))
700
+
701
+ # For BertTokenizer only: repeat relevant parts for do_lower_case=False,
702
+ # default_seq_length=10, experimental_disable_assert=False,
703
+ # tokenize_with_offsets=False, and without folding the word/subword dimension.
704
+ def test_cased_length10(self):
705
+ preprocess = tf.saved_model.load(
706
+ self._do_export(["d", "##ef", "abc", "ABC"],
707
+ do_lower_case=False,
708
+ default_seq_length=10,
709
+ tokenize_with_offsets=False,
710
+ use_sp_model=False,
711
+ experimental_disable_assert=False))
712
+ inputs = tf.constant(["abc def", "ABC DEF"])
713
+ token_ids = preprocess.tokenize(inputs)
714
+ self.assertAllEqual(token_ids,
715
+ tf.ragged.constant([[[6], [4, 5]], [[7], [1]]]))
716
+
717
+ self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
718
+
719
+ bert_inputs = preprocess(inputs)
720
+ self.assertAllEqual(
721
+ bert_inputs["input_word_ids"],
722
+ tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
723
+ [2, 7, 1, 3, 0, 0, 0, 0, 0, 0]]))
724
+ self.assertAllEqual(
725
+ bert_inputs["input_mask"],
726
+ tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
727
+ [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]))
728
+ self.assertAllEqual(
729
+ bert_inputs["input_type_ids"],
730
+ tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
731
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
732
+
733
+ inputs_2 = tf.constant(["d ABC", "ABC abc"])
734
+ token_ids_2 = preprocess.tokenize(inputs_2)
735
+ bert_inputs = preprocess.bert_pack_inputs([token_ids, token_ids_2])
736
+ # Test default seq_length=10.
737
+ self.assertAllEqual(
738
+ bert_inputs["input_word_ids"],
739
+ tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
740
+ [2, 7, 1, 3, 7, 6, 3, 0, 0, 0]]))
741
+ self.assertAllEqual(
742
+ bert_inputs["input_mask"],
743
+ tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
744
+ [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]))
745
+ self.assertAllEqual(
746
+ bert_inputs["input_type_ids"],
747
+ tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
748
+ [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]]))
749
+
750
+ # XLA requires fixed shapes for tensors found in graph mode.
751
+ # Statically known shapes in Python are a particularly firm way to
752
+ # guarantee that, and they are generally more convenient to work with.
753
+ # We test that the exported SavedModel plays well with TF's shape
754
+ # inference when applied to fully or partially known input shapes.
755
+ @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
756
+ def test_shapes(self, use_sp_model):
757
+ preprocess = tf.saved_model.load(
758
+ self._do_export(
759
+ ["abc", "def"],
760
+ do_lower_case=True,
761
+ # TODO(b/181866850): drop this.
762
+ tokenize_with_offsets=not use_sp_model,
763
+ # TODO(b/175369555): drop this.
764
+ experimental_disable_assert=True,
765
+ use_sp_model=use_sp_model))
766
+
767
+ def expected_bert_input_shapes(batch_size, seq_length):
768
+ return dict(
769
+ input_word_ids=[batch_size, seq_length],
770
+ input_mask=[batch_size, seq_length],
771
+ input_type_ids=[batch_size, seq_length])
772
+
773
+ for batch_size in [7, None]:
774
+ if use_sp_model:
775
+ token_out_shape = [batch_size, None] # No word/subword distinction.
776
+ else:
777
+ token_out_shape = [batch_size, None, None]
778
+ self.assertEqual(
779
+ _result_shapes_in_tf_function(preprocess.tokenize,
780
+ tf.TensorSpec([batch_size], tf.string)),
781
+ token_out_shape, "with batch_size=%s" % batch_size)
782
+ # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
783
+ if use_sp_model:
784
+ self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
785
+ else:
786
+ self.assertEqual(
787
+ _result_shapes_in_tf_function(
788
+ preprocess.tokenize_with_offsets,
789
+ tf.TensorSpec([batch_size], tf.string)), [token_out_shape] * 3,
790
+ "with batch_size=%s" % batch_size)
791
+ self.assertEqual(
792
+ _result_shapes_in_tf_function(
793
+ preprocess.bert_pack_inputs,
794
+ [tf.RaggedTensorSpec([batch_size, None, None], tf.int32)] * 2,
795
+ seq_length=256), expected_bert_input_shapes(batch_size, 256),
796
+ "with batch_size=%s" % batch_size)
797
+ self.assertEqual(
798
+ _result_shapes_in_tf_function(preprocess,
799
+ tf.TensorSpec([batch_size], tf.string)),
800
+ expected_bert_input_shapes(batch_size, 128),
801
+ "with batch_size=%s" % batch_size)
802
+
803
+ @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
804
+ def test_reexport(self, use_sp_model):
805
+ """Test that preprocess keeps working after another save/load cycle."""
806
+ path1 = self._do_export(
807
+ ["d", "ef", "abc", "xy"],
808
+ do_lower_case=True,
809
+ default_seq_length=10,
810
+ tokenize_with_offsets=False,
811
+ experimental_disable_assert=True, # TODO(b/175369555): drop this.
812
+ use_sp_model=use_sp_model)
813
+ path2 = path1.rstrip("/") + ".2"
814
+ model1 = tf.saved_model.load(path1)
815
+ tf.saved_model.save(model1, path2)
816
+ # Delete the first SavedModel to test that the sceond one loads by itself.
817
+ # https://github.com/tensorflow/tensorflow/issues/46456 reports such a
818
+ # failure case for BertTokenizer.
819
+ tf.io.gfile.rmtree(path1)
820
+ model2 = tf.saved_model.load(path2)
821
+
822
+ inputs = tf.constant(["abc d ef", "ABC D EF d"])
823
+ bert_inputs = model2(inputs)
824
+ self.assertAllEqual(
825
+ bert_inputs["input_word_ids"],
826
+ tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
827
+ [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
828
+ self.assertAllEqual(
829
+ bert_inputs["input_mask"],
830
+ tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
831
+ [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
832
+ self.assertAllEqual(
833
+ bert_inputs["input_type_ids"],
834
+ tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
835
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
836
+
837
+ @parameterized.named_parameters(("Bert", True), ("Albert", False))
838
+ def test_preprocessing_for_mlm(self, use_bert):
839
+ """Combines both SavedModel types and TF.text helpers for MLM."""
840
+ # Create the preprocessing SavedModel with a [MASK] token.
841
+ non_special_tokens = [
842
+ "hello", "world", "nice", "movie", "great", "actors", "quick", "fox",
843
+ "lazy", "dog"
844
+ ]
845
+
846
+ preprocess = tf.saved_model.load(
847
+ self._do_export(
848
+ non_special_tokens,
849
+ do_lower_case=True,
850
+ tokenize_with_offsets=use_bert, # TODO(b/181866850): drop this.
851
+ experimental_disable_assert=True, # TODO(b/175369555): drop this.
852
+ add_mask_token=True,
853
+ use_sp_model=not use_bert))
854
+ vocab_size = len(non_special_tokens) + (5 if use_bert else 7)
855
+
856
+ # Create the encoder SavedModel with an .mlm subobject.
857
+ hidden_size = 16
858
+ num_hidden_layers = 2
859
+ bert_config, encoder_config = _get_bert_config_or_encoder_config(
860
+ use_bert_config=use_bert,
861
+ hidden_size=hidden_size,
862
+ num_hidden_layers=num_hidden_layers,
863
+ vocab_size=vocab_size)
864
+ _, pretrainer = export_tfhub_lib._create_model(
865
+ bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
866
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
867
+ checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
868
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
869
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
870
+ vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy( # Not used below.
871
+ self.get_temp_dir(), use_sp_model=not use_bert)
872
+ encoder_export_path = os.path.join(self.get_temp_dir(), "encoder_export")
873
+ export_tfhub_lib.export_model(
874
+ export_path=encoder_export_path,
875
+ bert_config=bert_config,
876
+ encoder_config=encoder_config,
877
+ model_checkpoint_path=model_checkpoint_path,
878
+ with_mlm=True,
879
+ vocab_file=vocab_file,
880
+ sp_model_file=sp_model_file,
881
+ do_lower_case=True)
882
+ encoder = tf.saved_model.load(encoder_export_path)
883
+
884
+ # Get special tokens from the vocab (and vocab size).
885
+ special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
886
+ self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
887
+ padding_id = int(special_tokens_dict["padding_id"])
888
+ self.assertEqual(padding_id, 0)
889
+ start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
890
+ self.assertEqual(start_of_sequence_id, 2)
891
+ end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
892
+ self.assertEqual(end_of_segment_id, 3)
893
+ mask_id = int(special_tokens_dict["mask_id"])
894
+ self.assertEqual(mask_id, 4)
895
+
896
+ # A batch of 3 segment pairs.
897
+ raw_segments = [
898
+ tf.constant(["hello", "nice movie", "quick fox"]),
899
+ tf.constant(["world", "great actors", "lazy dog"])
900
+ ]
901
+ batch_size = 3
902
+
903
+ # Misc hyperparameters.
904
+ seq_length = 10
905
+ max_selections_per_seq = 2
906
+
907
+ # Tokenize inputs.
908
+ tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]
909
+ # Trim inputs to eventually fit seq_lentgh.
910
+ num_special_tokens = len(raw_segments) + 1
911
+ trimmed_segments = text.WaterfallTrimmer(
912
+ seq_length - num_special_tokens).trim(tokenized_segments)
913
+ # Combine input segments into one input sequence.
914
+ input_ids, segment_ids = text.combine_segments(
915
+ trimmed_segments,
916
+ start_of_sequence_id=start_of_sequence_id,
917
+ end_of_segment_id=end_of_segment_id)
918
+ # Apply random masking controlled by policy objects.
919
+ (masked_input_ids, masked_lm_positions,
920
+ masked_ids) = text.mask_language_model(
921
+ input_ids=input_ids,
922
+ item_selector=text.RandomItemSelector(
923
+ max_selections_per_seq,
924
+ selection_rate=0.5, # Adjusted for the short test examples.
925
+ unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
926
+ mask_values_chooser=text.MaskValuesChooser(
927
+ vocab_size=vocab_size,
928
+ mask_token=mask_id,
929
+ # Always put [MASK] to have a predictable result.
930
+ mask_token_rate=1.0,
931
+ random_token_rate=0.0))
932
+ # Pad to fixed-length Transformer encoder inputs.
933
+ input_word_ids, _ = text.pad_model_inputs(
934
+ masked_input_ids, seq_length, pad_value=padding_id)
935
+ input_type_ids, input_mask = text.pad_model_inputs(
936
+ segment_ids, seq_length, pad_value=0)
937
+ masked_lm_positions, _ = text.pad_model_inputs(
938
+ masked_lm_positions, max_selections_per_seq, pad_value=0)
939
+ masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
940
+ num_predictions = int(tf.shape(masked_lm_positions)[1])
941
+
942
+ # Test transformer inputs.
943
+ self.assertEqual(num_predictions, max_selections_per_seq)
944
+ expected_word_ids = np.array([
945
+ # [CLS] hello [SEP] world [SEP]
946
+ [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
947
+ # [CLS] nice movie [SEP] great actors [SEP]
948
+ [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
949
+ # [CLS] brown fox [SEP] lazy dog [SEP]
950
+ [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]
951
+ ])
952
+ for i in range(batch_size):
953
+ for j in range(num_predictions):
954
+ k = int(masked_lm_positions[i, j])
955
+ if k != 0:
956
+ expected_word_ids[i, k] = 4 # [MASK]
957
+ self.assertAllEqual(input_word_ids, expected_word_ids)
958
+
959
+ # Call the MLM head of the Transformer encoder.
960
+ mlm_inputs = dict(
961
+ input_word_ids=input_word_ids,
962
+ input_mask=input_mask,
963
+ input_type_ids=input_type_ids,
964
+ masked_lm_positions=masked_lm_positions,
965
+ )
966
+ mlm_outputs = encoder.mlm(mlm_inputs)
967
+ self.assertEqual(mlm_outputs["pooled_output"].shape,
968
+ (batch_size, hidden_size))
969
+ self.assertEqual(mlm_outputs["sequence_output"].shape,
970
+ (batch_size, seq_length, hidden_size))
971
+ self.assertEqual(mlm_outputs["mlm_logits"].shape,
972
+ (batch_size, num_predictions, vocab_size))
973
+ self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)
974
+
975
+ # A real trainer would now compute the loss of mlm_logits
976
+ # trying to predict the masked_ids.
977
+ del masked_ids # Unused.
978
+
979
+ @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
980
+ def test_special_tokens_in_estimator(self, use_sp_model):
981
+ """Tests getting special tokens without an Eager init context."""
982
+ preprocess_export_path = self._do_export(["d", "ef", "abc", "xy"],
983
+ do_lower_case=True,
984
+ use_sp_model=use_sp_model,
985
+ tokenize_with_offsets=False)
986
+
987
+ def _get_special_tokens_dict(obj):
988
+ """Returns special tokens of restored tokenizer as Python values."""
989
+ if tf.executing_eagerly():
990
+ special_tokens_numpy = {
991
+ k: v.numpy() for k, v in obj.get_special_tokens_dict()
992
+ }
993
+ else:
994
+ with tf.Graph().as_default():
995
+ # This code expects `get_special_tokens_dict()` to be a tf.function
996
+ # with no dependencies (bound args) from the context it was loaded in,
997
+ # and boldly assumes that it can just be called in a dfferent context.
998
+ special_tokens_tensors = obj.get_special_tokens_dict()
999
+ with tf.compat.v1.Session() as sess:
1000
+ special_tokens_numpy = sess.run(special_tokens_tensors)
1001
+ return {
1002
+ k: v.item() # Numpy to Python.
1003
+ for k, v in special_tokens_numpy.items()
1004
+ }
1005
+
1006
+ def input_fn():
1007
+ self.assertFalse(tf.executing_eagerly())
1008
+ # Build a preprocessing Model.
1009
+ sentences = tf_keras.layers.Input(shape=[], dtype=tf.string)
1010
+ preprocess = tf.saved_model.load(preprocess_export_path)
1011
+ tokenize = hub.KerasLayer(preprocess.tokenize)
1012
+ special_tokens_dict = _get_special_tokens_dict(tokenize.resolved_object)
1013
+ for k, v in special_tokens_dict.items():
1014
+ self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
1015
+ tokens = tokenize(sentences)
1016
+ packed_inputs = layers.BertPackInputs(
1017
+ 4, special_tokens_dict=special_tokens_dict)(
1018
+ tokens)
1019
+ preprocessing = tf_keras.Model(sentences, packed_inputs)
1020
+ # Map the dataset.
1021
+ ds = tf.data.Dataset.from_tensors(
1022
+ (tf.constant(["abc", "D EF"]), tf.constant([0, 1])))
1023
+ ds = ds.map(lambda features, labels: (preprocessing(features), labels))
1024
+ return ds
1025
+
1026
+ def model_fn(features, labels, mode):
1027
+ del labels # Unused.
1028
+ return tf_estimator.EstimatorSpec(
1029
+ mode=mode, predictions=features["input_word_ids"])
1030
+
1031
+ estimator = tf_estimator.Estimator(model_fn=model_fn)
1032
+ outputs = list(estimator.predict(input_fn))
1033
+ self.assertAllEqual(outputs, np.array([[2, 6, 3, 0], [2, 4, 5, 3]]))
1034
+
1035
+ # TODO(b/175369555): Remove that code and its test.
1036
+ @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
1037
+ def test_check_no_assert(self, use_sp_model):
1038
+ """Tests the self-check during export without assertions."""
1039
+ preprocess_export_path = self._do_export(["d", "ef", "abc", "xy"],
1040
+ do_lower_case=True,
1041
+ use_sp_model=use_sp_model,
1042
+ tokenize_with_offsets=False,
1043
+ experimental_disable_assert=False)
1044
+ with self.assertRaisesRegex(AssertionError,
1045
+ r"failed to suppress \d+ Assert ops"):
1046
+ export_tfhub_lib._check_no_assert(preprocess_export_path)
1047
+
1048
+
1049
+ def _result_shapes_in_tf_function(fn, *args, **kwargs):
1050
+ """Returns shapes (as lists) observed on the result of `fn`.
1051
+
1052
+ Args:
1053
+ fn: A callable.
1054
+ *args: TensorSpecs for Tensor-valued arguments and actual values for
1055
+ Python-valued arguments to fn.
1056
+ **kwargs: Same for keyword arguments.
1057
+
1058
+ Returns:
1059
+ The nest of partial tensor shapes (as lists) that is statically known inside
1060
+ tf.function(fn)(*args, **kwargs) for the nest of its results.
1061
+ """
1062
+ # Use a captured mutable container for a side outout from the wrapper.
1063
+ uninitialized = "uninitialized!"
1064
+ result_shapes_container = [uninitialized]
1065
+ assert result_shapes_container[0] is uninitialized
1066
+
1067
+ @tf.function
1068
+ def shape_reporting_wrapper(*args, **kwargs):
1069
+ result = fn(*args, **kwargs)
1070
+ result_shapes_container[0] = tf.nest.map_structure(
1071
+ lambda x: x.shape.as_list(), result)
1072
+ return result
1073
+
1074
+ shape_reporting_wrapper.get_concrete_function(*args, **kwargs)
1075
+ assert result_shapes_container[0] is not uninitialized
1076
+ return result_shapes_container[0]
1077
+
1078
+
1079
+ if __name__ == "__main__":
1080
+ tf.test.main()
squad_evaluate_v1_1.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluation of SQuAD predictions (version 1.1).
16
+
17
+ The functions are copied from
18
+ https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
19
+
20
+ The SQuAD dataset is described in this paper:
21
+ SQuAD: 100,000+ Questions for Machine Comprehension of Text
22
+ Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
23
+ https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
24
+ """
25
+
26
+ import collections
27
+ import re
28
+ import string
29
+
30
+ # pylint: disable=g-bad-import-order
31
+
32
+ from absl import logging
33
+ # pylint: enable=g-bad-import-order
34
+
35
+
36
+ def _normalize_answer(s):
37
+ """Lowers text and remove punctuation, articles and extra whitespace."""
38
+
39
+ def remove_articles(text):
40
+ return re.sub(r"\b(a|an|the)\b", " ", text)
41
+
42
+ def white_space_fix(text):
43
+ return " ".join(text.split())
44
+
45
+ def remove_punc(text):
46
+ exclude = set(string.punctuation)
47
+ return "".join(ch for ch in text if ch not in exclude)
48
+
49
+ def lower(text):
50
+ return text.lower()
51
+
52
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
53
+
54
+
55
+ def _f1_score(prediction, ground_truth):
56
+ """Computes F1 score by comparing prediction to ground truth."""
57
+ prediction_tokens = _normalize_answer(prediction).split()
58
+ ground_truth_tokens = _normalize_answer(ground_truth).split()
59
+ prediction_counter = collections.Counter(prediction_tokens)
60
+ ground_truth_counter = collections.Counter(ground_truth_tokens)
61
+ common = prediction_counter & ground_truth_counter
62
+ num_same = sum(common.values())
63
+ if num_same == 0:
64
+ return 0
65
+ precision = 1.0 * num_same / len(prediction_tokens)
66
+ recall = 1.0 * num_same / len(ground_truth_tokens)
67
+ f1 = (2 * precision * recall) / (precision + recall)
68
+ return f1
69
+
70
+
71
+ def _exact_match_score(prediction, ground_truth):
72
+ """Checks if predicted answer exactly matches ground truth answer."""
73
+ return _normalize_answer(prediction) == _normalize_answer(ground_truth)
74
+
75
+
76
+ def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
77
+ """Computes the max over all metric scores."""
78
+ scores_for_ground_truths = []
79
+ for ground_truth in ground_truths:
80
+ score = metric_fn(prediction, ground_truth)
81
+ scores_for_ground_truths.append(score)
82
+ return max(scores_for_ground_truths)
83
+
84
+
85
+ def evaluate(dataset, predictions):
86
+ """Evaluates predictions for a dataset."""
87
+ f1 = exact_match = total = 0
88
+ for article in dataset:
89
+ for paragraph in article["paragraphs"]:
90
+ for qa in paragraph["qas"]:
91
+ total += 1
92
+ if qa["id"] not in predictions:
93
+ message = "Unanswered question " + qa["id"] + " will receive score 0."
94
+ logging.error(message)
95
+ continue
96
+ ground_truths = [entry["text"] for entry in qa["answers"]]
97
+ prediction = predictions[qa["id"]]
98
+ exact_match += _metric_max_over_ground_truths(_exact_match_score,
99
+ prediction, ground_truths)
100
+ f1 += _metric_max_over_ground_truths(_f1_score, prediction,
101
+ ground_truths)
102
+
103
+ exact_match = exact_match / total
104
+ f1 = f1 / total
105
+
106
+ return {"exact_match": exact_match, "final_f1": f1}
squad_evaluate_v2_0.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluation script for SQuAD version 2.0.
16
+
17
+ The functions are copied and modified from
18
+ https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
19
+
20
+ In addition to basic functionality, we also compute additional statistics and
21
+ plot precision-recall curves if an additional na_prob.json file is provided.
22
+ This file is expected to map question ID's to the model's predicted probability
23
+ that a question is unanswerable.
24
+ """
25
+
26
+ import collections
27
+ import re
28
+ import string
29
+
30
+ from absl import logging
31
+
32
+
33
+ def _make_qid_to_has_ans(dataset):
34
+ qid_to_has_ans = {}
35
+ for article in dataset:
36
+ for p in article['paragraphs']:
37
+ for qa in p['qas']:
38
+ qid_to_has_ans[qa['id']] = bool(qa['answers'])
39
+ return qid_to_has_ans
40
+
41
+
42
+ def _normalize_answer(s):
43
+ """Lower text and remove punctuation, articles and extra whitespace."""
44
+ def remove_articles(text):
45
+ regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
46
+ return re.sub(regex, ' ', text)
47
+ def white_space_fix(text):
48
+ return ' '.join(text.split())
49
+ def remove_punc(text):
50
+ exclude = set(string.punctuation)
51
+ return ''.join(ch for ch in text if ch not in exclude)
52
+ def lower(text):
53
+ return text.lower()
54
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
55
+
56
+
57
+ def _get_tokens(s):
58
+ if not s: return []
59
+ return _normalize_answer(s).split()
60
+
61
+
62
+ def _compute_exact(a_gold, a_pred):
63
+ return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))
64
+
65
+
66
+ def _compute_f1(a_gold, a_pred):
67
+ """Compute F1-score."""
68
+ gold_toks = _get_tokens(a_gold)
69
+ pred_toks = _get_tokens(a_pred)
70
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
71
+ num_same = sum(common.values())
72
+ if not gold_toks or not pred_toks:
73
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
74
+ return int(gold_toks == pred_toks)
75
+ if num_same == 0:
76
+ return 0
77
+ precision = 1.0 * num_same / len(pred_toks)
78
+ recall = 1.0 * num_same / len(gold_toks)
79
+ f1 = (2 * precision * recall) / (precision + recall)
80
+ return f1
81
+
82
+
83
+ def _get_raw_scores(dataset, predictions):
84
+ """Compute raw scores."""
85
+ exact_scores = {}
86
+ f1_scores = {}
87
+ for article in dataset:
88
+ for p in article['paragraphs']:
89
+ for qa in p['qas']:
90
+ qid = qa['id']
91
+ gold_answers = [a['text'] for a in qa['answers']
92
+ if _normalize_answer(a['text'])]
93
+ if not gold_answers:
94
+ # For unanswerable questions, only correct answer is empty string
95
+ gold_answers = ['']
96
+ if qid not in predictions:
97
+ logging.error('Missing prediction for %s', qid)
98
+ continue
99
+ a_pred = predictions[qid]
100
+ # Take max over all gold answers
101
+ exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
102
+ f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
103
+ return exact_scores, f1_scores
104
+
105
+
106
+ def _apply_no_ans_threshold(
107
+ scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
108
+ new_scores = {}
109
+ for qid, s in scores.items():
110
+ pred_na = na_probs[qid] > na_prob_thresh
111
+ if pred_na:
112
+ new_scores[qid] = float(not qid_to_has_ans[qid])
113
+ else:
114
+ new_scores[qid] = s
115
+ return new_scores
116
+
117
+
118
+ def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
119
+ """Make evaluation result dictionary."""
120
+ if not qid_list:
121
+ total = len(exact_scores)
122
+ return collections.OrderedDict([
123
+ ('exact', 100.0 * sum(exact_scores.values()) / total),
124
+ ('f1', 100.0 * sum(f1_scores.values()) / total),
125
+ ('total', total),
126
+ ])
127
+ else:
128
+ total = len(qid_list)
129
+ return collections.OrderedDict([
130
+ ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
131
+ ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
132
+ ('total', total),
133
+ ])
134
+
135
+
136
+ def _merge_eval(main_eval, new_eval, prefix):
137
+ for k in new_eval:
138
+ main_eval['%s_%s' % (prefix, k)] = new_eval[k]
139
+
140
+
141
+ def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
142
+ """Make evaluation dictionary containing average recision recall."""
143
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
144
+ true_pos = 0.0
145
+ cur_p = 1.0
146
+ cur_r = 0.0
147
+ precisions = [1.0]
148
+ recalls = [0.0]
149
+ avg_prec = 0.0
150
+ for i, qid in enumerate(qid_list):
151
+ if qid_to_has_ans[qid]:
152
+ true_pos += scores[qid]
153
+ cur_p = true_pos / float(i+1)
154
+ cur_r = true_pos / float(num_true_pos)
155
+ if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
156
+ # i.e., if we can put a threshold after this point
157
+ avg_prec += cur_p * (cur_r - recalls[-1])
158
+ precisions.append(cur_p)
159
+ recalls.append(cur_r)
160
+ return {'ap': 100.0 * avg_prec}
161
+
162
+
163
+ def _run_precision_recall_analysis(
164
+ main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
165
+ """Run precision recall analysis and return result dictionary."""
166
+ num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
167
+ if num_true_pos == 0:
168
+ return
169
+ pr_exact = _make_precision_recall_eval(
170
+ exact_raw, na_probs, num_true_pos, qid_to_has_ans)
171
+ pr_f1 = _make_precision_recall_eval(
172
+ f1_raw, na_probs, num_true_pos, qid_to_has_ans)
173
+ oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
174
+ pr_oracle = _make_precision_recall_eval(
175
+ oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
176
+ _merge_eval(main_eval, pr_exact, 'pr_exact')
177
+ _merge_eval(main_eval, pr_f1, 'pr_f1')
178
+ _merge_eval(main_eval, pr_oracle, 'pr_oracle')
179
+
180
+
181
+ def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
182
+ """Find the best threshold for no answer probability."""
183
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
184
+ cur_score = num_no_ans
185
+ best_score = cur_score
186
+ best_thresh = 0.0
187
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
188
+ for qid in qid_list:
189
+ if qid not in scores: continue
190
+ if qid_to_has_ans[qid]:
191
+ diff = scores[qid]
192
+ else:
193
+ if predictions[qid]:
194
+ diff = -1
195
+ else:
196
+ diff = 0
197
+ cur_score += diff
198
+ if cur_score > best_score:
199
+ best_score = cur_score
200
+ best_thresh = na_probs[qid]
201
+ return 100.0 * best_score / len(scores), best_thresh
202
+
203
+
204
+ def _find_all_best_thresh(
205
+ main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
206
+ best_exact, exact_thresh = _find_best_thresh(
207
+ predictions, exact_raw, na_probs, qid_to_has_ans)
208
+ best_f1, f1_thresh = _find_best_thresh(
209
+ predictions, f1_raw, na_probs, qid_to_has_ans)
210
+ main_eval['final_exact'] = best_exact
211
+ main_eval['final_exact_thresh'] = exact_thresh
212
+ main_eval['final_f1'] = best_f1
213
+ main_eval['final_f1_thresh'] = f1_thresh
214
+
215
+
216
+ def evaluate(dataset, predictions, na_probs=None):
217
+ """Evaluate prediction results."""
218
+ new_orig_data = []
219
+ for article in dataset:
220
+ for p in article['paragraphs']:
221
+ for qa in p['qas']:
222
+ if qa['id'] in predictions:
223
+ new_para = {'qas': [qa]}
224
+ new_article = {'paragraphs': [new_para]}
225
+ new_orig_data.append(new_article)
226
+ dataset = new_orig_data
227
+
228
+ if na_probs is None:
229
+ na_probs = {k: 0.0 for k in predictions}
230
+ qid_to_has_ans = _make_qid_to_has_ans(dataset) # maps qid to True/False
231
+ has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
232
+ no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
233
+ exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
234
+ exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
235
+ f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
236
+ out_eval = _make_eval_dict(exact_thresh, f1_thresh)
237
+ if has_ans_qids:
238
+ has_ans_eval = _make_eval_dict(
239
+ exact_thresh, f1_thresh, qid_list=has_ans_qids)
240
+ _merge_eval(out_eval, has_ans_eval, 'HasAns')
241
+ if no_ans_qids:
242
+ no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
243
+ _merge_eval(out_eval, no_ans_eval, 'NoAns')
244
+
245
+ _find_all_best_thresh(
246
+ out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
247
+ _run_precision_recall_analysis(
248
+ out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
249
+ return out_eval
tf1_bert_checkpoint_converter_lib.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
16
+
17
+ import numpy as np
18
+ import tensorflow.compat.v1 as tf # TF 1.x
19
+
20
+ # Mapping between old <=> new names. The source pattern in original variable
21
+ # name will be replaced by destination pattern.
22
+ BERT_NAME_REPLACEMENTS = (
23
+ ("bert", "bert_model"),
24
+ ("embeddings/word_embeddings", "word_embeddings/embeddings"),
25
+ ("embeddings/token_type_embeddings",
26
+ "embedding_postprocessor/type_embeddings"),
27
+ ("embeddings/position_embeddings",
28
+ "embedding_postprocessor/position_embeddings"),
29
+ ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
30
+ ("attention/self", "self_attention"),
31
+ ("attention/output/dense", "self_attention_output"),
32
+ ("attention/output/LayerNorm", "self_attention_layer_norm"),
33
+ ("intermediate/dense", "intermediate"),
34
+ ("output/dense", "output"),
35
+ ("output/LayerNorm", "output_layer_norm"),
36
+ ("pooler/dense", "pooler_transform"),
37
+ )
38
+
39
+ BERT_V2_NAME_REPLACEMENTS = (
40
+ ("bert/", ""),
41
+ ("encoder", "transformer"),
42
+ ("embeddings/word_embeddings", "word_embeddings/embeddings"),
43
+ ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
44
+ ("embeddings/position_embeddings", "position_embedding/embeddings"),
45
+ ("embeddings/LayerNorm", "embeddings/layer_norm"),
46
+ ("attention/self", "self_attention"),
47
+ ("attention/output/dense", "self_attention/attention_output"),
48
+ ("attention/output/LayerNorm", "self_attention_layer_norm"),
49
+ ("intermediate/dense", "intermediate"),
50
+ ("output/dense", "output"),
51
+ ("output/LayerNorm", "output_layer_norm"),
52
+ ("pooler/dense", "pooler_transform"),
53
+ ("cls/predictions", "bert/cls/predictions"),
54
+ ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
55
+ ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
56
+ ("cls/seq_relationship/output_weights",
57
+ "predictions/transform/logits/kernel"),
58
+ )
59
+
60
+ BERT_PERMUTATIONS = ()
61
+
62
+ BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
63
+
64
+
65
+ def _bert_name_replacement(var_name, name_replacements):
66
+ """Gets the variable name replacement."""
67
+ for src_pattern, tgt_pattern in name_replacements:
68
+ if src_pattern in var_name:
69
+ old_var_name = var_name
70
+ var_name = var_name.replace(src_pattern, tgt_pattern)
71
+ tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
72
+ return var_name
73
+
74
+
75
+ def _has_exclude_patterns(name, exclude_patterns):
76
+ """Checks if a string contains substrings that match patterns to exclude."""
77
+ for p in exclude_patterns:
78
+ if p in name:
79
+ return True
80
+ return False
81
+
82
+
83
+ def _get_permutation(name, permutations):
84
+ """Checks whether a variable requires transposition by pattern matching."""
85
+ for src_pattern, permutation in permutations:
86
+ if src_pattern in name:
87
+ tf.logging.info("Permuted: %s --> %s", name, permutation)
88
+ return permutation
89
+
90
+ return None
91
+
92
+
93
+ def _get_new_shape(name, shape, num_heads):
94
+ """Checks whether a variable requires reshape by pattern matching."""
95
+ if "self_attention/attention_output/kernel" in name:
96
+ return tuple([num_heads, shape[0] // num_heads, shape[1]])
97
+ if "self_attention/attention_output/bias" in name:
98
+ return shape
99
+
100
+ patterns = [
101
+ "self_attention/query", "self_attention/value", "self_attention/key"
102
+ ]
103
+ for pattern in patterns:
104
+ if pattern in name:
105
+ if "kernel" in name:
106
+ return tuple([shape[0], num_heads, shape[1] // num_heads])
107
+ if "bias" in name:
108
+ return tuple([num_heads, shape[0] // num_heads])
109
+ return None
110
+
111
+
112
+ def create_v2_checkpoint(model,
113
+ src_checkpoint,
114
+ output_path,
115
+ checkpoint_model_name="model"):
116
+ """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
117
+ # Uses streaming-restore in eager model to read V1 name-based checkpoints.
118
+ model.load_weights(src_checkpoint).assert_existing_objects_matched()
119
+ if hasattr(model, "checkpoint_items"):
120
+ checkpoint_items = model.checkpoint_items
121
+ else:
122
+ checkpoint_items = {}
123
+
124
+ checkpoint_items[checkpoint_model_name] = model
125
+ checkpoint = tf.train.Checkpoint(**checkpoint_items)
126
+ checkpoint.save(output_path)
127
+
128
+
129
+ def convert(checkpoint_from_path,
130
+ checkpoint_to_path,
131
+ num_heads,
132
+ name_replacements,
133
+ permutations,
134
+ exclude_patterns=None):
135
+ """Migrates the names of variables within a checkpoint.
136
+
137
+ Args:
138
+ checkpoint_from_path: Path to source checkpoint to be read in.
139
+ checkpoint_to_path: Path to checkpoint to be written out.
140
+ num_heads: The number of heads of the model.
141
+ name_replacements: A list of tuples of the form (match_str, replace_str)
142
+ describing variable names to adjust.
143
+ permutations: A list of tuples of the form (match_str, permutation)
144
+ describing permutations to apply to given variables. Note that match_str
145
+ should match the original variable name, not the replaced one.
146
+ exclude_patterns: A list of string patterns to exclude variables from
147
+ checkpoint conversion.
148
+
149
+ Returns:
150
+ A dictionary that maps the new variable names to the Variable objects.
151
+ A dictionary that maps the old variable names to the new variable names.
152
+ """
153
+ with tf.Graph().as_default():
154
+ tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
155
+ reader = tf.train.NewCheckpointReader(checkpoint_from_path)
156
+ name_shape_map = reader.get_variable_to_shape_map()
157
+ new_variable_map = {}
158
+ conversion_map = {}
159
+ for var_name in name_shape_map:
160
+ if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
161
+ continue
162
+ # Get the original tensor data.
163
+ tensor = reader.get_tensor(var_name)
164
+
165
+ # Look up the new variable name, if any.
166
+ new_var_name = _bert_name_replacement(var_name, name_replacements)
167
+
168
+ # See if we need to reshape the underlying tensor.
169
+ new_shape = None
170
+ if num_heads > 0:
171
+ new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
172
+ if new_shape:
173
+ tf.logging.info("Veriable %s has a shape change from %s to %s",
174
+ var_name, tensor.shape, new_shape)
175
+ tensor = np.reshape(tensor, new_shape)
176
+
177
+ # See if we need to permute the underlying tensor.
178
+ permutation = _get_permutation(var_name, permutations)
179
+ if permutation:
180
+ tensor = np.transpose(tensor, permutation)
181
+
182
+ # Create a new variable with the possibly-reshaped or transposed tensor.
183
+ var = tf.Variable(tensor, name=var_name)
184
+
185
+ # Save the variable into the new variable map.
186
+ new_variable_map[new_var_name] = var
187
+
188
+ # Keep a list of converter variables for sanity checking.
189
+ if new_var_name != var_name:
190
+ conversion_map[var_name] = new_var_name
191
+
192
+ saver = tf.train.Saver(new_variable_map)
193
+
194
+ with tf.Session() as sess:
195
+ sess.run(tf.global_variables_initializer())
196
+ tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
197
+ saver.save(sess, checkpoint_to_path, write_meta_graph=False)
198
+
199
+ tf.logging.info("Summary:")
200
+ tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
201
+ tf.logging.info(" Converted: %s", str(conversion_map))
tf2_albert_encoder_checkpoint_converter.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A converter from a tf1 ALBERT encoder checkpoint to a tf2 encoder checkpoint.
16
+
17
+ The conversion will yield an object-oriented checkpoint that can be used
18
+ to restore an AlbertEncoder object.
19
+ """
20
+ import os
21
+
22
+ from absl import app
23
+ from absl import flags
24
+
25
+ import tensorflow as tf, tf_keras
26
+ from official.legacy.albert import configs
27
+ from official.modeling import tf_utils
28
+ from official.nlp.modeling import models
29
+ from official.nlp.modeling import networks
30
+ from official.nlp.tools import tf1_bert_checkpoint_converter_lib
31
+
32
+ FLAGS = flags.FLAGS
33
+
34
+ flags.DEFINE_string("albert_config_file", None,
35
+ "Albert configuration file to define core bert layers.")
36
+ flags.DEFINE_string(
37
+ "checkpoint_to_convert", None,
38
+ "Initial checkpoint from a pretrained BERT model core (that is, only the "
39
+ "BertModel, with no task heads.)")
40
+ flags.DEFINE_string("converted_checkpoint_path", None,
41
+ "Name for the created object-based V2 checkpoint.")
42
+ flags.DEFINE_string("checkpoint_model_name", "encoder",
43
+ "The name of the model when saving the checkpoint, i.e., "
44
+ "the checkpoint will be saved using: "
45
+ "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
46
+ flags.DEFINE_enum(
47
+ "converted_model", "encoder", ["encoder", "pretrainer"],
48
+ "Whether to convert the checkpoint to a `AlbertEncoder` model or a "
49
+ "`BertPretrainerV2` model (with mlm but without classification heads).")
50
+
51
+
52
+ ALBERT_NAME_REPLACEMENTS = (
53
+ ("bert/encoder/", ""),
54
+ ("bert/", ""),
55
+ ("embeddings/word_embeddings", "word_embeddings/embeddings"),
56
+ ("embeddings/position_embeddings", "position_embedding/embeddings"),
57
+ ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
58
+ ("embeddings/LayerNorm", "embeddings/layer_norm"),
59
+ ("embedding_hidden_mapping_in", "embedding_projection"),
60
+ ("group_0/inner_group_0/", ""),
61
+ ("attention_1/self", "self_attention"),
62
+ ("attention_1/output/dense", "self_attention/attention_output"),
63
+ ("transformer/LayerNorm/", "transformer/self_attention_layer_norm/"),
64
+ ("ffn_1/intermediate/dense", "intermediate"),
65
+ ("ffn_1/intermediate/output/dense", "output"),
66
+ ("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
67
+ ("pooler/dense", "pooler_transform"),
68
+ ("cls/predictions", "bert/cls/predictions"),
69
+ ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
70
+ ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
71
+ ("cls/seq_relationship/output_weights",
72
+ "predictions/transform/logits/kernel"),
73
+ )
74
+
75
+
76
+ def _create_albert_model(cfg):
77
+ """Creates an ALBERT keras core model from BERT configuration.
78
+
79
+ Args:
80
+ cfg: A `AlbertConfig` to create the core model.
81
+
82
+ Returns:
83
+ A keras model.
84
+ """
85
+ albert_encoder = networks.AlbertEncoder(
86
+ vocab_size=cfg.vocab_size,
87
+ hidden_size=cfg.hidden_size,
88
+ embedding_width=cfg.embedding_size,
89
+ num_layers=cfg.num_hidden_layers,
90
+ num_attention_heads=cfg.num_attention_heads,
91
+ intermediate_size=cfg.intermediate_size,
92
+ activation=tf_utils.get_activation(cfg.hidden_act),
93
+ dropout_rate=cfg.hidden_dropout_prob,
94
+ attention_dropout_rate=cfg.attention_probs_dropout_prob,
95
+ max_sequence_length=cfg.max_position_embeddings,
96
+ type_vocab_size=cfg.type_vocab_size,
97
+ initializer=tf_keras.initializers.TruncatedNormal(
98
+ stddev=cfg.initializer_range))
99
+ return albert_encoder
100
+
101
+
102
+ def _create_pretrainer_model(cfg):
103
+ """Creates a pretrainer with AlbertEncoder from ALBERT configuration.
104
+
105
+ Args:
106
+ cfg: A `BertConfig` to create the core model.
107
+
108
+ Returns:
109
+ A BertPretrainerV2 model.
110
+ """
111
+ albert_encoder = _create_albert_model(cfg)
112
+ pretrainer = models.BertPretrainerV2(
113
+ encoder_network=albert_encoder,
114
+ mlm_activation=tf_utils.get_activation(cfg.hidden_act),
115
+ mlm_initializer=tf_keras.initializers.TruncatedNormal(
116
+ stddev=cfg.initializer_range))
117
+ # Makes sure masked_lm layer's variables in pretrainer are created.
118
+ _ = pretrainer(pretrainer.inputs)
119
+ return pretrainer
120
+
121
+
122
+ def convert_checkpoint(bert_config, output_path, v1_checkpoint,
123
+ checkpoint_model_name,
124
+ converted_model="encoder"):
125
+ """Converts a V1 checkpoint into an OO V2 checkpoint."""
126
+ output_dir, _ = os.path.split(output_path)
127
+
128
+ # Create a temporary V1 name-converted checkpoint in the output directory.
129
+ temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
130
+ temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
131
+ tf1_bert_checkpoint_converter_lib.convert(
132
+ checkpoint_from_path=v1_checkpoint,
133
+ checkpoint_to_path=temporary_checkpoint,
134
+ num_heads=bert_config.num_attention_heads,
135
+ name_replacements=ALBERT_NAME_REPLACEMENTS,
136
+ permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
137
+ exclude_patterns=["adam", "Adam"])
138
+
139
+ # Create a V2 checkpoint from the temporary checkpoint.
140
+ if converted_model == "encoder":
141
+ model = _create_albert_model(bert_config)
142
+ elif converted_model == "pretrainer":
143
+ model = _create_pretrainer_model(bert_config)
144
+ else:
145
+ raise ValueError("Unsupported converted_model: %s" % converted_model)
146
+
147
+ tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
148
+ model, temporary_checkpoint, output_path, checkpoint_model_name)
149
+
150
+ # Clean up the temporary checkpoint, if it exists.
151
+ try:
152
+ tf.io.gfile.rmtree(temporary_checkpoint_dir)
153
+ except tf.errors.OpError:
154
+ # If it doesn't exist, we don't need to clean it up; continue.
155
+ pass
156
+
157
+
158
+ def main(_):
159
+ output_path = FLAGS.converted_checkpoint_path
160
+ v1_checkpoint = FLAGS.checkpoint_to_convert
161
+ checkpoint_model_name = FLAGS.checkpoint_model_name
162
+ converted_model = FLAGS.converted_model
163
+ albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
164
+ convert_checkpoint(albert_config, output_path, v1_checkpoint,
165
+ checkpoint_model_name,
166
+ converted_model=converted_model)
167
+
168
+
169
+ if __name__ == "__main__":
170
+ app.run(main)
tf2_bert_encoder_checkpoint_converter.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
16
+
17
+ The conversion will yield an object-oriented checkpoint that can be used
18
+ to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
19
+ FLAG below).
20
+ """
21
+
22
+ import os
23
+
24
+ from absl import app
25
+ from absl import flags
26
+
27
+ import tensorflow as tf, tf_keras
28
+ from official.legacy.bert import configs
29
+ from official.modeling import tf_utils
30
+ from official.nlp.modeling import models
31
+ from official.nlp.modeling import networks
32
+ from official.nlp.tools import tf1_bert_checkpoint_converter_lib
33
+
34
+ FLAGS = flags.FLAGS
35
+
36
+ flags.DEFINE_string("bert_config_file", None,
37
+ "Bert configuration file to define core bert layers.")
38
+ flags.DEFINE_string(
39
+ "checkpoint_to_convert", None,
40
+ "Initial checkpoint from a pretrained BERT model core (that is, only the "
41
+ "BertModel, with no task heads.)")
42
+ flags.DEFINE_string("converted_checkpoint_path", None,
43
+ "Name for the created object-based V2 checkpoint.")
44
+ flags.DEFINE_string("checkpoint_model_name", "encoder",
45
+ "The name of the model when saving the checkpoint, i.e., "
46
+ "the checkpoint will be saved using: "
47
+ "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
48
+ flags.DEFINE_enum(
49
+ "converted_model", "encoder", ["encoder", "pretrainer"],
50
+ "Whether to convert the checkpoint to a `BertEncoder` model or a "
51
+ "`BertPretrainerV2` model (with mlm but without classification heads).")
52
+
53
+
54
+ def _create_bert_model(cfg):
55
+ """Creates a BERT keras core model from BERT configuration.
56
+
57
+ Args:
58
+ cfg: A `BertConfig` to create the core model.
59
+
60
+ Returns:
61
+ A BertEncoder network.
62
+ """
63
+ bert_encoder = networks.BertEncoder(
64
+ vocab_size=cfg.vocab_size,
65
+ hidden_size=cfg.hidden_size,
66
+ num_layers=cfg.num_hidden_layers,
67
+ num_attention_heads=cfg.num_attention_heads,
68
+ intermediate_size=cfg.intermediate_size,
69
+ activation=tf_utils.get_activation(cfg.hidden_act),
70
+ dropout_rate=cfg.hidden_dropout_prob,
71
+ attention_dropout_rate=cfg.attention_probs_dropout_prob,
72
+ max_sequence_length=cfg.max_position_embeddings,
73
+ type_vocab_size=cfg.type_vocab_size,
74
+ initializer=tf_keras.initializers.TruncatedNormal(
75
+ stddev=cfg.initializer_range),
76
+ embedding_width=cfg.embedding_size)
77
+
78
+ return bert_encoder
79
+
80
+
81
+ def _create_bert_pretrainer_model(cfg):
82
+ """Creates a BERT keras core model from BERT configuration.
83
+
84
+ Args:
85
+ cfg: A `BertConfig` to create the core model.
86
+
87
+ Returns:
88
+ A BertPretrainerV2 model.
89
+ """
90
+ bert_encoder = _create_bert_model(cfg)
91
+ pretrainer = models.BertPretrainerV2(
92
+ encoder_network=bert_encoder,
93
+ mlm_activation=tf_utils.get_activation(cfg.hidden_act),
94
+ mlm_initializer=tf_keras.initializers.TruncatedNormal(
95
+ stddev=cfg.initializer_range))
96
+ # Makes sure the pretrainer variables are created.
97
+ _ = pretrainer(pretrainer.inputs)
98
+ return pretrainer
99
+
100
+
101
+ def convert_checkpoint(bert_config,
102
+ output_path,
103
+ v1_checkpoint,
104
+ checkpoint_model_name="model",
105
+ converted_model="encoder"):
106
+ """Converts a V1 checkpoint into an OO V2 checkpoint."""
107
+ output_dir, _ = os.path.split(output_path)
108
+ tf.io.gfile.makedirs(output_dir)
109
+
110
+ # Create a temporary V1 name-converted checkpoint in the output directory.
111
+ temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
112
+ temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
113
+
114
+ tf1_bert_checkpoint_converter_lib.convert(
115
+ checkpoint_from_path=v1_checkpoint,
116
+ checkpoint_to_path=temporary_checkpoint,
117
+ num_heads=bert_config.num_attention_heads,
118
+ name_replacements=(
119
+ tf1_bert_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS),
120
+ permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
121
+ exclude_patterns=["adam", "Adam"])
122
+
123
+ if converted_model == "encoder":
124
+ model = _create_bert_model(bert_config)
125
+ elif converted_model == "pretrainer":
126
+ model = _create_bert_pretrainer_model(bert_config)
127
+ else:
128
+ raise ValueError("Unsupported converted_model: %s" % converted_model)
129
+
130
+ # Create a V2 checkpoint from the temporary checkpoint.
131
+ tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
132
+ model, temporary_checkpoint, output_path, checkpoint_model_name)
133
+
134
+ # Clean up the temporary checkpoint, if it exists.
135
+ try:
136
+ tf.io.gfile.rmtree(temporary_checkpoint_dir)
137
+ except tf.errors.OpError:
138
+ # If it doesn't exist, we don't need to clean it up; continue.
139
+ pass
140
+
141
+
142
+ def main(argv):
143
+ if len(argv) > 1:
144
+ raise app.UsageError("Too many command-line arguments.")
145
+
146
+ output_path = FLAGS.converted_checkpoint_path
147
+ v1_checkpoint = FLAGS.checkpoint_to_convert
148
+ checkpoint_model_name = FLAGS.checkpoint_model_name
149
+ converted_model = FLAGS.converted_model
150
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
151
+ convert_checkpoint(
152
+ bert_config=bert_config,
153
+ output_path=output_path,
154
+ v1_checkpoint=v1_checkpoint,
155
+ checkpoint_model_name=checkpoint_model_name,
156
+ converted_model=converted_model)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ app.run(main)
tokenization_test.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import tempfile
17
+
18
+ import six
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.nlp.tools import tokenization
22
+
23
+
24
+ class TokenizationTest(tf.test.TestCase):
25
+ """Tokenization test.
26
+
27
+ The implementation is forked from
28
+ https://github.com/google-research/bert/blob/master/tokenization_test.py."
29
+ """
30
+
31
+ def test_full_tokenizer(self):
32
+ vocab_tokens = [
33
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
34
+ "##ing", ","
35
+ ]
36
+ with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
37
+ if six.PY2:
38
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
39
+ else:
40
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens
41
+ ]).encode("utf-8"))
42
+
43
+ vocab_file = vocab_writer.name
44
+
45
+ tokenizer = tokenization.FullTokenizer(vocab_file)
46
+ os.unlink(vocab_file)
47
+
48
+ tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
49
+ self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
50
+
51
+ self.assertAllEqual(
52
+ tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
53
+
54
+ def test_chinese(self):
55
+ tokenizer = tokenization.BasicTokenizer()
56
+
57
+ self.assertAllEqual(
58
+ tokenizer.tokenize(u"ah\u535A\u63A8zz"),
59
+ [u"ah", u"\u535A", u"\u63A8", u"zz"])
60
+
61
+ def test_basic_tokenizer_lower(self):
62
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
63
+
64
+ self.assertAllEqual(
65
+ tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
66
+ ["hello", "!", "how", "are", "you", "?"])
67
+ self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
68
+
69
+ def test_basic_tokenizer_no_lower(self):
70
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
71
+
72
+ self.assertAllEqual(
73
+ tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
74
+ ["HeLLo", "!", "how", "Are", "yoU", "?"])
75
+
76
+ def test_basic_tokenizer_no_split_on_punc(self):
77
+ tokenizer = tokenization.BasicTokenizer(
78
+ do_lower_case=True, split_on_punc=False)
79
+
80
+ self.assertAllEqual(
81
+ tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
82
+ ["hello!how", "are", "you?"])
83
+
84
+ def test_wordpiece_tokenizer(self):
85
+ vocab_tokens = [
86
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
87
+ "##ing", "##!", "!"
88
+ ]
89
+
90
+ vocab = {}
91
+ for (i, token) in enumerate(vocab_tokens):
92
+ vocab[token] = i
93
+ tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
94
+
95
+ self.assertAllEqual(tokenizer.tokenize(""), [])
96
+
97
+ self.assertAllEqual(
98
+ tokenizer.tokenize("unwanted running"),
99
+ ["un", "##want", "##ed", "runn", "##ing"])
100
+
101
+ self.assertAllEqual(
102
+ tokenizer.tokenize("unwanted running !"),
103
+ ["un", "##want", "##ed", "runn", "##ing", "!"])
104
+
105
+ self.assertAllEqual(
106
+ tokenizer.tokenize("unwanted running!"),
107
+ ["un", "##want", "##ed", "runn", "##ing", "##!"])
108
+
109
+ self.assertAllEqual(
110
+ tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
111
+
112
+ def test_convert_tokens_to_ids(self):
113
+ vocab_tokens = [
114
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
115
+ "##ing"
116
+ ]
117
+
118
+ vocab = {}
119
+ for (i, token) in enumerate(vocab_tokens):
120
+ vocab[token] = i
121
+
122
+ self.assertAllEqual(
123
+ tokenization.convert_tokens_to_ids(
124
+ vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
125
+
126
+ def test_is_whitespace(self):
127
+ self.assertTrue(tokenization._is_whitespace(u" "))
128
+ self.assertTrue(tokenization._is_whitespace(u"\t"))
129
+ self.assertTrue(tokenization._is_whitespace(u"\r"))
130
+ self.assertTrue(tokenization._is_whitespace(u"\n"))
131
+ self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
132
+
133
+ self.assertFalse(tokenization._is_whitespace(u"A"))
134
+ self.assertFalse(tokenization._is_whitespace(u"-"))
135
+
136
+ def test_is_control(self):
137
+ self.assertTrue(tokenization._is_control(u"\u0005"))
138
+
139
+ self.assertFalse(tokenization._is_control(u"A"))
140
+ self.assertFalse(tokenization._is_control(u" "))
141
+ self.assertFalse(tokenization._is_control(u"\t"))
142
+ self.assertFalse(tokenization._is_control(u"\r"))
143
+ self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
144
+
145
+ def test_is_punctuation(self):
146
+ self.assertTrue(tokenization._is_punctuation(u"-"))
147
+ self.assertTrue(tokenization._is_punctuation(u"$"))
148
+ self.assertTrue(tokenization._is_punctuation(u"`"))
149
+ self.assertTrue(tokenization._is_punctuation(u"."))
150
+
151
+ self.assertFalse(tokenization._is_punctuation(u"A"))
152
+ self.assertFalse(tokenization._is_punctuation(u" "))
153
+
154
+
155
+ if __name__ == "__main__":
156
+ tf.test.main()