# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub. | |
This tool creates preprocessor and encoder SavedModels suitable for uploading | |
to https://tfhub.dev that implement the preprocessor and encoder APIs defined | |
at https://www.tensorflow.org/hub/common_saved_model_apis/text. | |
For a full usage guide, see | |
https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md | |
Minimal usage examples: | |
1) Exporting an Encoder from checkpoint and config.

```
export_tfhub \
  --encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
  --model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
  --vocab_file=${BERT_DIR:?}/vocab.txt \
  --export_type=model \
  --export_path=/tmp/bert_model
```
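
The exported SavedModel implements the encoder API linked above. A minimal
sketch of loading and calling it (the token ids below are illustrative
placeholders, and `tensorflow_hub` is assumed to be installed):

```
import tensorflow as tf
import tensorflow_hub as hub

encoder = hub.load("/tmp/bert_model")
encoder_inputs = dict(
    # Shapes are [batch_size, seq_length]; ids below are placeholders.
    input_word_ids=tf.constant([[101, 7592, 102, 0]], dtype=tf.int32),
    input_mask=tf.constant([[1, 1, 1, 0]], dtype=tf.int32),
    input_type_ids=tf.constant([[0, 0, 0, 0]], dtype=tf.int32),
)
outputs = encoder(encoder_inputs)
pooled = outputs["pooled_output"]      # [batch_size, hidden_size]
sequence = outputs["sequence_output"]  # [batch_size, seq_length, hidden_size]
```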
An --encoder_config_file can specify encoder types other than BERT.
For BERT, a --bert_config_file in the legacy JSON format can be passed instead.

Flags --vocab_file and --do_lower_case (whose default value is guessed from
the vocab_file path) capture how BertTokenizer was used in pre-training.
Use flag --sp_model_file instead if SentencepieceTokenizer was used.

Changing --export_type to model_with_mlm additionally creates an `.mlm`
subobject on the exported SavedModel that can be called to produce
the logits of the Masked Language Model task from pretraining (see the
sketch below). The help string for flag --model_checkpoint_path explains
the checkpoint formats required for each --export_type.
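
A minimal sketch of calling the `.mlm` subobject, reusing `encoder` from the
sketch above. The `[MASK]` token id 103 and the `mlm_logits` output key follow
the conventions documented for BERT models on tfhub.dev; treat them as
assumptions to verify against your particular vocabulary and export:

```
mlm_inputs = dict(
    # Token id 103 is [MASK] in the standard BERT vocab; values are
    # placeholders, shapes are [batch_size, seq_length].
    input_word_ids=tf.constant([[101, 103, 102, 0]], dtype=tf.int32),
    input_mask=tf.constant([[1, 1, 1, 0]], dtype=tf.int32),
    input_type_ids=tf.constant([[0, 0, 0, 0]], dtype=tf.int32),
    # Positions (into the sequence) of the tokens to predict.
    masked_lm_positions=tf.constant([[1]], dtype=tf.int32),
)
mlm_outputs = encoder.mlm(mlm_inputs)
logits = mlm_outputs["mlm_logits"]  # [batch_size, num_predictions, vocab_size]
```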
2) Exporting a preprocessor SavedModel.

```
export_tfhub \
  --vocab_file ${BERT_DIR:?}/vocab.txt \
  --export_type preprocessing --export_path /tmp/bert_preprocessing
```
Be sure to use flag values that match the encoder and how it has been
pre-trained (see above for --vocab_file vs --sp_model_file).

If your encoder has been trained with text preprocessing for which tfhub.dev
already has a SavedModel, you could guide your users to reuse that one instead
of exporting and publishing your own.
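
A minimal sketch of pairing an exported preprocessor with an exported encoder,
per the preprocessor API linked at the top (the paths are the example export
paths used above):

```
import tensorflow as tf
import tensorflow_hub as hub

preprocessor = hub.load("/tmp/bert_preprocessing")
encoder = hub.load("/tmp/bert_model")

sentences = tf.constant(["Hello, TF Hub!", "A second example."])
encoder_inputs = preprocessor(sentences)  # packed to default_seq_length
outputs = encoder(encoder_inputs)
```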
TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
`--experimental_disable_assert_in_preprocessing`.
"""

from absl import app
from absl import flags
import gin

from official.legacy.bert import configs
from official.modeling import hyperparams
from official.nlp.configs import encoders
from official.nlp.tools import export_tfhub_lib

FLAGS = flags.FLAGS
flags.DEFINE_enum(
    "export_type", "model",
    ["model", "model_with_mlm", "preprocessing"],
    "The overall type of SavedModel to export. Flags "
    "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
    "control which particular encoder model and preprocessing are exported.")
flags.DEFINE_string(
    "export_path", None,
    "Directory to which the SavedModel is written.")
flags.DEFINE_string(
    "encoder_config_file", None,
    "A yaml file representing `encoders.EncoderConfig` to define the encoder "
    "(BERT or other). "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_string(
    "bert_config_file", None,
    "A JSON file with a legacy BERT configuration to define the BERT encoder. "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_bool(
    "copy_pooler_dense_to_encoder", False,
    "When the model is trained using `BertPretrainerV2`, the pooling layer "
    "of the next-sentence-prediction task exists in the `ClassificationHead` "
    "passed to `BertPretrainerV2`. If True, we will copy this pooler's dense "
    "layer to the encoder that is exported by this tool (as in classic BERT). "
    "Using `BertPretrainerV2` and leaving this False exports an untrained "
    "(randomly initialized) pooling layer, which some authors recommend for "
    "subsequent fine-tuning.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to a pre-trained model checkpoint. "
    "For --export_type model, this has to be an object-based (TF2) checkpoint "
    "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
    "for the `encoder` defined by the config file. "
    "(Legacy checkpoints with `model=` instead of `encoder=` are also "
    "supported for now.) "
    "For --export_type model_with_mlm, it must be restorable to "
    "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
    "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
    "accepted.)")
flags.DEFINE_string(
    "vocab_file", None,
    "For encoders trained on BertTokenizer input: "
    "the vocabulary file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_string(
    "sp_model_file", None,
    "For encoders trained on SentencepieceTokenizer input: "
    "the SentencePiece .model file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_bool(
    "do_lower_case", None,
    "Whether to lowercase before tokenization. "
    "If left as None and --vocab_file is set, do_lower_case is enabled "
    "if 'uncased' appears in the name of --vocab_file. "
    "If left as None and --sp_model_file is set, do_lower_case defaults to "
    "True. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_integer(
    "default_seq_length", 128,
    "The sequence length of preprocessing results from the "
    "top-level preprocess method. This is also the default "
    "sequence length for the bert_pack_inputs subobject. "
    "Needed for --export_type preprocessing.")
flags.DEFINE_bool(
    "tokenize_with_offsets", False,  # TODO(b/181866850)
    "Whether to export a .tokenize_with_offsets subobject for "
    "--export_type preprocessing.")
flags.DEFINE_multi_string(
    "gin_file", default=None,
    help="List of paths to the config files.")
flags.DEFINE_multi_string(
    "gin_params", default=None,
    help="List of Gin bindings.")
flags.DEFINE_bool(  # TODO(b/175369555): Remove this flag and its use.
    "experimental_disable_assert_in_preprocessing", False,
    "Export a preprocessing model without tf.Assert ops. "
    "Usually, that would be a bad idea, except TF2.4 has an issue with "
    "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
    "and omitting the Assert ops lets SavedModels avoid the issue.")


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

  if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
    raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                     "can be specified, but got %s and %s." %
                     (FLAGS.vocab_file, FLAGS.sp_model_file))
  do_lower_case = export_tfhub_lib.get_do_lower_case(
      FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)

  if FLAGS.export_type in ("model", "model_with_mlm"):
    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
      raise ValueError("Exactly one of `bert_config_file` and "
                       "`encoder_config_file` can be specified, but got "
                       "%s and %s." %
                       (FLAGS.bert_config_file, FLAGS.encoder_config_file))
    if FLAGS.bert_config_file:
      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
      encoder_config = None
    else:
      bert_config = None
      encoder_config = encoders.EncoderConfig()
      encoder_config = hyperparams.override_params_dict(
          encoder_config, FLAGS.encoder_config_file, is_strict=True)
    export_tfhub_lib.export_model(
        FLAGS.export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=FLAGS.model_checkpoint_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        with_mlm=FLAGS.export_type == "model_with_mlm",
        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)

  elif FLAGS.export_type == "preprocessing":
    export_tfhub_lib.export_preprocessing(
        FLAGS.export_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        default_seq_length=FLAGS.default_seq_length,
        tokenize_with_offsets=FLAGS.tokenize_with_offsets,
        experimental_disable_assert=FLAGS
        .experimental_disable_assert_in_preprocessing)

  else:
    raise app.UsageError(
        "Unknown value '%s' for flag --export_type" % FLAGS.export_type)


if __name__ == "__main__":
  app.run(main)