Pradeep Kumar committed on
Delete export_tfhub.py
export_tfhub.py +0 -219
export_tfhub.py
DELETED
@@ -1,219 +0,0 @@
-# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.
-
-This tool creates preprocessor and encoder SavedModels suitable for uploading
-to https://tfhub.dev that implement the preprocessor and encoder APIs defined
-at https://www.tensorflow.org/hub/common_saved_model_apis/text.
-
-For a full usage guide, see
-https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md
-
-Minimal usage examples:
-
-1) Exporting an Encoder from checkpoint and config.
-
-```
-export_tfhub \
-  --encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
-  --model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
-  --vocab_file=${BERT_DIR:?}/vocab.txt \
-  --export_type=model \
-  --export_path=/tmp/bert_model
-```
-
-An --encoder_config_file can specify encoder types other than BERT.
-For BERT, a --bert_config_file in the legacy JSON format can be passed instead.
-
-Flag --vocab_file (and flag --do_lower_case, whose default value is guessed
-from the vocab_file path) captures how BertTokenizer was used in pre-training.
-Use flag --sp_model_file instead if SentencepieceTokenizer was used.
-
-Changing --export_type to model_with_mlm additionally creates an `.mlm`
-subobject on the exported SavedModel that can be called to produce
-the logits of the Masked Language Model task from pretraining.
-The help string for flag --model_checkpoint_path explains the checkpoint
-formats required for each --export_type.
-
-
-2) Exporting a preprocessor SavedModel
-
-```
-export_tfhub \
-  --vocab_file ${BERT_DIR:?}/vocab.txt \
-  --export_type preprocessing --export_path /tmp/bert_preprocessing
-```
-
-Be sure to use flag values that match the encoder and how it has been
-pre-trained (see above for --vocab_file vs --sp_model_file).
-
-If your encoder has been trained with text preprocessing for which tfhub.dev
-already has a SavedModel, you could guide your users to reuse that one instead
-of exporting and publishing your own.
-
-TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
-`--experimental_disable_assert_in_preprocessing`.
-"""
-
-from absl import app
-from absl import flags
-import gin
-
-from official.legacy.bert import configs
-from official.modeling import hyperparams
-from official.nlp.configs import encoders
-from official.nlp.tools import export_tfhub_lib
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_enum(
-    "export_type", "model",
-    ["model", "model_with_mlm", "preprocessing"],
-    "The overall type of SavedModel to export. Flags "
-    "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
-    "control which particular encoder model and preprocessing are exported.")
-flags.DEFINE_string(
-    "export_path", None,
-    "Directory to which the SavedModel is written.")
-flags.DEFINE_string(
-    "encoder_config_file", None,
-    "A yaml file representing `encoders.EncoderConfig` to define the encoder "
-    "(BERT or other). "
-    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
-    "Needed for --export_type model and model_with_mlm.")
-flags.DEFINE_string(
-    "bert_config_file", None,
-    "A JSON file with a legacy BERT configuration to define the BERT encoder. "
-    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
-    "Needed for --export_type model and model_with_mlm.")
-flags.DEFINE_bool(
-    "copy_pooler_dense_to_encoder", False,
-    "When the model is trained using `BertPretrainerV2`, the pool layer "
-    "of next sentence prediction task exists in `ClassificationHead` passed "
-    "to `BertPretrainerV2`. If True, we will copy this pooler's dense layer "
-    "to the encoder that is exported by this tool (as in classic BERT). "
-    "Using `BertPretrainerV2` and leaving this False exports an untrained "
-    "(randomly initialized) pooling layer, which some authors recommend for "
-    "subsequent fine-tuning.")
-flags.DEFINE_string(
-    "model_checkpoint_path", None,
-    "File path to a pre-trained model checkpoint. "
-    "For --export_type model, this has to be an object-based (TF2) checkpoint "
-    "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
-    "for the `encoder` defined by the config file. "
-    "(Legacy checkpoints with `model=` instead of `encoder=` are also "
-    "supported for now.) "
-    "For --export_type model_with_mlm, it must be restorable to "
-    "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
-    "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
-    "accepted.)")
-flags.DEFINE_string(
-    "vocab_file", None,
-    "For encoders trained on BertTokenizer input: "
-    "the vocabulary file that the encoder model was trained with. "
-    "Exactly one of --vocab_file and --sp_model_file can be set. "
-    "Needed for --export_type model, model_with_mlm and preprocessing.")
-flags.DEFINE_string(
-    "sp_model_file", None,
-    "For encoders trained on SentencepieceTokenizer input: "
-    "the SentencePiece .model file that the encoder model was trained with. "
-    "Exactly one of --vocab_file and --sp_model_file can be set. "
-    "Needed for --export_type model, model_with_mlm and preprocessing.")
-flags.DEFINE_bool(
-    "do_lower_case", None,
-    "Whether to lowercase before tokenization. "
-    "If left as None, and --vocab_file is set, do_lower_case will be enabled "
-    "if 'uncased' appears in the name of --vocab_file. "
-    "If left as None, and --sp_model_file is set, do_lower_case defaults to true. "
-    "Needed for --export_type model, model_with_mlm and preprocessing.")
-flags.DEFINE_integer(
-    "default_seq_length", 128,
-    "The sequence length of preprocessing results from "
-    "top-level preprocess method. This is also the default "
-    "sequence length for the bert_pack_inputs subobject. "
-    "Needed for --export_type preprocessing.")
-flags.DEFINE_bool(
-    "tokenize_with_offsets", False,  # TODO(b/181866850)
-    "Whether to export a .tokenize_with_offsets subobject for "
-    "--export_type preprocessing.")
-flags.DEFINE_multi_string(
-    "gin_file", default=None,
-    help="List of paths to the config files.")
-flags.DEFINE_multi_string(
-    "gin_params", default=None,
-    help="List of Gin bindings.")
-flags.DEFINE_bool(  # TODO(b/175369555): Remove this flag and its use.
-    "experimental_disable_assert_in_preprocessing", False,
-    "Export a preprocessing model without tf.Assert ops. "
-    "Usually, that would be a bad idea, except TF2.4 has an issue with "
-    "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
-    "and omitting the Assert ops lets SavedModels avoid the issue.")
-
-
-def main(argv):
-  if len(argv) > 1:
-    raise app.UsageError("Too many command-line arguments.")
-  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
-
-  if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
-    raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
-                     "can be specified, but got %s and %s." %
-                     (FLAGS.vocab_file, FLAGS.sp_model_file))
-  do_lower_case = export_tfhub_lib.get_do_lower_case(
-      FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)
-
-  if FLAGS.export_type in ("model", "model_with_mlm"):
-    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
-      raise ValueError("Exactly one of `bert_config_file` and "
-                       "`encoder_config_file` can be specified, but got "
-                       "%s and %s." %
-                       (FLAGS.bert_config_file, FLAGS.encoder_config_file))
-    if FLAGS.bert_config_file:
-      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
-      encoder_config = None
-    else:
-      bert_config = None
-      encoder_config = encoders.EncoderConfig()
-      encoder_config = hyperparams.override_params_dict(
-          encoder_config, FLAGS.encoder_config_file, is_strict=True)
-    export_tfhub_lib.export_model(
-        FLAGS.export_path,
-        bert_config=bert_config,
-        encoder_config=encoder_config,
-        model_checkpoint_path=FLAGS.model_checkpoint_path,
-        vocab_file=FLAGS.vocab_file,
-        sp_model_file=FLAGS.sp_model_file,
-        do_lower_case=do_lower_case,
-        with_mlm=FLAGS.export_type == "model_with_mlm",
-        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)
-
-  elif FLAGS.export_type == "preprocessing":
-    export_tfhub_lib.export_preprocessing(
-        FLAGS.export_path,
-        vocab_file=FLAGS.vocab_file,
-        sp_model_file=FLAGS.sp_model_file,
-        do_lower_case=do_lower_case,
-        default_seq_length=FLAGS.default_seq_length,
-        tokenize_with_offsets=FLAGS.tokenize_with_offsets,
-        experimental_disable_assert=
-        FLAGS.experimental_disable_assert_in_preprocessing)
-
-  else:
-    raise app.UsageError(
-        "Unknown value '%s' for flag --export_type" % FLAGS.export_type)
-
-
-if __name__ == "__main__":
-  app.run(main)
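For context on what this deletion removes: consumers of the exported SavedModels wire the preprocessor and encoder together following the TF Hub text API linked in the docstring above. A minimal sketch (not part of the deleted file), assuming the /tmp/bert_preprocessing and /tmp/bert_model export paths from the usage examples:

```
import tensorflow as tf
import tensorflow_hub as hub

# Load the two exports; the paths are the ones used in the examples above.
preprocessor = hub.KerasLayer("/tmp/bert_preprocessing")
encoder = hub.KerasLayer("/tmp/bert_model", trainable=True)

# The preprocessor maps a batch of strings to the dict of int32 tensors
# (input_word_ids, input_mask, input_type_ids) that the encoder expects.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
encoder_inputs = preprocessor(text_input)
outputs = encoder(encoder_inputs)
pooled = outputs["pooled_output"]      # [batch_size, hidden_size]
sequence = outputs["sequence_output"]  # [batch_size, seq_length, hidden_size]

# A classification head on the pooled output yields a fine-tunable model.
model = tf.keras.Model(text_input, tf.keras.layers.Dense(2)(pooled))
```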
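For --export_type=model_with_mlm, the docstring notes an `.mlm` subobject on the export. A hedged sketch of calling it, following the MLM extension of the same TF Hub text API; the toy token ids and shapes here are illustrative assumptions, not taken from the source:

```
import tensorflow as tf
import tensorflow_hub as hub

mlm_model = hub.load("/tmp/bert_model")  # exported with --export_type=model_with_mlm

# masked_lm_positions selects the token slots whose MLM logits to compute;
# the other three tensors are ordinary preprocessor outputs.
mlm_inputs = dict(
    input_word_ids=tf.constant([[101, 103, 102, 0]], dtype=tf.int32),
    input_mask=tf.constant([[1, 1, 1, 0]], dtype=tf.int32),
    input_type_ids=tf.constant([[0, 0, 0, 0]], dtype=tf.int32),
    masked_lm_positions=tf.constant([[1]], dtype=tf.int32))
mlm_outputs = mlm_model.mlm(mlm_inputs)
logits = mlm_outputs["mlm_logits"]  # [batch_size, num_predictions, vocab_size]
```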