# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.

This tool creates preprocessor and encoder SavedModels suitable for uploading
to https://tfhub.dev that implement the preprocessor and encoder APIs defined
at https://www.tensorflow.org/hub/common_saved_model_apis/text.

For a full usage guide, see
https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md

Minimal usage examples:

1) Exporting an Encoder from checkpoint and config.

```
export_tfhub \
  --encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
  --model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
  --vocab_file=${BERT_DIR:?}/vocab.txt \
  --export_type=model \
  --export_path=/tmp/bert_model
```

An --encoder_config_file can specify encoder types other than BERT.
For BERT, a --bert_config_file in the legacy JSON format can be passed instead.

Flags --vocab_file and --do_lower_case (whose default value is guessed from
the vocab_file path) capture how BertTokenizer was used in pre-training.
Use flag --sp_model_file instead if SentencepieceTokenizer was used.

Changing --export_type to model_with_mlm additionally creates an `.mlm`
subobject on the exported SavedModel that can be called to produce
the logits of the Masked Language Model task from pretraining.
The help string for flag --model_checkpoint_path explains the checkpoint
formats required for each --export_type.
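
After export, the encoder can be loaded and called through tensorflow_hub,
following the encoder API linked above. A minimal sketch, assuming the
/tmp/bert_model path from the example above and dummy inputs in place of
real tokenized text:

```
import tensorflow as tf
import tensorflow_hub as hub

encoder = hub.load("/tmp/bert_model")  # the --export_path used above

# The encoder API expects a dict of int32 Tensors of equal shape.
seq_length = 128
encoder_inputs = dict(
    input_word_ids=tf.zeros([1, seq_length], tf.int32),  # token ids (dummy)
    input_mask=tf.ones([1, seq_length], tf.int32),       # 1=token, 0=padding
    input_type_ids=tf.zeros([1, seq_length], tf.int32),  # segment ids
)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]      # [batch, hidden]
sequence_output = encoder_outputs["sequence_output"]  # [batch, seq, hidden]
```

With --export_type model_with_mlm, the exported object additionally carries
the `.mlm` subobject described above; see export_tfhub_lib.export_model for
its exact calling convention.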

2) Exporting a preprocessor SavedModel

```
export_tfhub \
  --vocab_file ${BERT_DIR:?}/vocab.txt \
  --export_type preprocessing --export_path /tmp/bert_preprocessing
```

Be sure to use flag values that match the encoder and how it was pre-trained
(see above for --vocab_file vs --sp_model_file).

If your encoder was trained with text preprocessing for which tfhub.dev
already hosts a SavedModel, consider pointing your users to that one instead
of exporting and publishing your own.
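
A preprocessor exported this way pairs with the encoder through the
preprocessor API linked above. A minimal sketch, assuming the /tmp export
paths from the examples in this docstring:

```
import tensorflow as tf
import tensorflow_hub as hub

preprocessor = hub.load("/tmp/bert_preprocessing")
encoder = hub.load("/tmp/bert_model")

# Map a batch of plain strings to the dict of int32 Tensors the encoder
# expects; sequences are padded/truncated to --default_seq_length.
sentences = tf.constant(["Hello TF Hub!", "A second example."])
encoder_inputs = preprocessor(sentences)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]
```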

TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
`--experimental_disable_assert_in_preprocessing`.
"""

from absl import app
from absl import flags
import gin

from official.legacy.bert import configs
from official.modeling import hyperparams
from official.nlp.configs import encoders
from official.nlp.tools import export_tfhub_lib

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "export_type", "model",
    ["model", "model_with_mlm", "preprocessing"],
    "The overall type of SavedModel to export. Flags "
    "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
    "control which particular encoder model and preprocessing are exported.")
flags.DEFINE_string(
    "export_path", None,
    "Directory to which the SavedModel is written.")
flags.DEFINE_string(
    "encoder_config_file", None,
    "A YAML file representing `encoders.EncoderConfig` to define the encoder "
    "(BERT or other). "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_string(
    "bert_config_file", None,
    "A JSON file with a legacy BERT configuration to define the BERT encoder. "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_bool(
    "copy_pooler_dense_to_encoder", False,
    "When the model is trained using `BertPretrainerV2`, the pooling layer "
    "of the next sentence prediction task lives in the `ClassificationHead` "
    "passed to `BertPretrainerV2`. If True, this pooler's dense layer is "
    "copied to the encoder exported by this tool (as in classic BERT). "
    "Using `BertPretrainerV2` and leaving this False exports an untrained "
    "(randomly initialized) pooling layer, which some authors recommend for "
    "subsequent fine-tuning.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to a pre-trained model checkpoint. "
    "For --export_type model, this has to be an object-based (TF2) checkpoint "
    "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
    "for the `encoder` defined by the config file. "
    "(Legacy checkpoints with `model=` instead of `encoder=` are also "
    "supported for now.) "
    "For --export_type model_with_mlm, it must be restorable to "
    "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
    "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
    "accepted.)")
flags.DEFINE_string(
    "vocab_file", None,
    "For encoders trained on BertTokenizer input: "
    "the vocabulary file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_string(
    "sp_model_file", None,
    "For encoders trained on SentencepieceTokenizer input: "
    "the SentencePiece .model file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_bool(
    "do_lower_case", None,
    "Whether to lowercase before tokenization. "
    "If left as None and --vocab_file is set, do_lower_case will be enabled "
    "if 'uncased' appears in the name of --vocab_file. "
    "If left as None and --sp_model_file is set, do_lower_case defaults to "
    "true. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_integer(
    "default_seq_length", 128,
    "The sequence length of preprocessing results from the top-level "
    "preprocess method. This is also the default sequence length for the "
    "bert_pack_inputs subobject. "
    "Needed for --export_type preprocessing.")
flags.DEFINE_bool(
    "tokenize_with_offsets", False,  # TODO(b/181866850)
    "Whether to export a .tokenize_with_offsets subobject for "
    "--export_type preprocessing.")
flags.DEFINE_multi_string(
    "gin_file", default=None,
    help="List of paths to Gin configuration files.")
flags.DEFINE_multi_string(
    "gin_params", default=None,
    help="List of Gin bindings.")
flags.DEFINE_bool(  # TODO(b/175369555): Remove this flag and its use.
    "experimental_disable_assert_in_preprocessing", False,
    "Export a preprocessing model without tf.Assert ops. "
    "Usually, that would be a bad idea, except TF2.4 has an issue with "
    "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
    "and omitting the Assert ops lets SavedModels avoid the issue.")

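# For reference, a minimal sketch (illustrative, not part of this tool) of
# writing a checkpoint that matches the --export_type model convention in the
# --model_checkpoint_path help above, given `import tensorflow as tf` and an
# `encoder_config` built as in main() below:
#
#   encoder = encoders.build_encoder(encoder_config)
#   ckpt_path = tf.train.Checkpoint(encoder=encoder).save("/tmp/ckpt/encoder")
#   # ckpt_path (e.g. "/tmp/ckpt/encoder-1") is what the flag expects.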

def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

  if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
    raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                     "must be specified, but got %s and %s." %
                     (FLAGS.vocab_file, FLAGS.sp_model_file))
  do_lower_case = export_tfhub_lib.get_do_lower_case(
      FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)

  if FLAGS.export_type in ("model", "model_with_mlm"):
    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
      raise ValueError("Exactly one of `bert_config_file` and "
                       "`encoder_config_file` must be specified, but got "
                       "%s and %s." %
                       (FLAGS.bert_config_file, FLAGS.encoder_config_file))
    if FLAGS.bert_config_file:
      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
      encoder_config = None
    else:
      bert_config = None
      encoder_config = encoders.EncoderConfig()
      encoder_config = hyperparams.override_params_dict(
          encoder_config, FLAGS.encoder_config_file, is_strict=True)
    export_tfhub_lib.export_model(
        FLAGS.export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=FLAGS.model_checkpoint_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        with_mlm=FLAGS.export_type == "model_with_mlm",
        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)

  elif FLAGS.export_type == "preprocessing":
    export_tfhub_lib.export_preprocessing(
        FLAGS.export_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        default_seq_length=FLAGS.default_seq_length,
        tokenize_with_offsets=FLAGS.tokenize_with_offsets,
        experimental_disable_assert=
        FLAGS.experimental_disable_assert_in_preprocessing)

  else:
    raise app.UsageError(
        "Unknown value '%s' for flag --export_type" % FLAGS.export_type)


if __name__ == "__main__":
  app.run(main)