Pradeep Kumar committed
Commit b88c2e8 · verified · 1 Parent(s): 7a97586

Delete export_tfhub_lib.py

Files changed (1)
  1. export_tfhub_lib.py +0 -493
export_tfhub_lib.py DELETED
@@ -1,493 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Library of components of export_tfhub.py. See docstring there for more."""
-
- import contextlib
- import hashlib
- import os
- import tempfile
-
- from typing import Optional, Text, Tuple
-
- # Import libraries
- from absl import logging
- import tensorflow as tf, tf_keras
- # pylint: disable=g-direct-tensorflow-import  TODO(b/175369555): Remove these.
- from tensorflow.core.protobuf import saved_model_pb2
- from tensorflow.python.ops import control_flow_assert
- # pylint: enable=g-direct-tensorflow-import
- from official.legacy.bert import configs
- from official.modeling import tf_utils
- from official.nlp.configs import encoders
- from official.nlp.modeling import layers
- from official.nlp.modeling import models
- from official.nlp.modeling import networks
-
-
- def get_bert_encoder(bert_config):
-   """Returns a BertEncoder with dict outputs."""
-   bert_encoder = networks.BertEncoder(
-       vocab_size=bert_config.vocab_size,
-       hidden_size=bert_config.hidden_size,
-       num_layers=bert_config.num_hidden_layers,
-       num_attention_heads=bert_config.num_attention_heads,
-       intermediate_size=bert_config.intermediate_size,
-       activation=tf_utils.get_activation(bert_config.hidden_act),
-       dropout_rate=bert_config.hidden_dropout_prob,
-       attention_dropout_rate=bert_config.attention_probs_dropout_prob,
-       max_sequence_length=bert_config.max_position_embeddings,
-       type_vocab_size=bert_config.type_vocab_size,
-       initializer=tf_keras.initializers.TruncatedNormal(
-           stddev=bert_config.initializer_range),
-       embedding_width=bert_config.embedding_size,
-       dict_outputs=True)
-
-   return bert_encoder
-
-
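For context on the deleted helper above, here is a minimal usage sketch. It assumes a BERT-base-like configuration; the hyperparameter values are illustrative, not taken from this commit.

```
# Sketch: build a dict-output BertEncoder from a legacy BertConfig.
# All hyperparameter values below are illustrative assumptions.
from official.legacy.bert import configs

bert_config = configs.BertConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072)
encoder = get_bert_encoder(bert_config)
# Calling encoder(...) yields a dict with "sequence_output",
# "pooled_output", "encoder_outputs", etc.
```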
- def get_do_lower_case(do_lower_case, vocab_file=None, sp_model_file=None):
-   """Returns do_lower_case, replacing None by a guess from vocab file name."""
-   if do_lower_case is not None:
-     return do_lower_case
-   elif vocab_file:
-     do_lower_case = "uncased" in vocab_file
-     logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
-                  do_lower_case, vocab_file)
-     return do_lower_case
-   elif sp_model_file:
-     do_lower_case = True  # All public ALBERTs (as of Oct 2020) do it.
-     logging.info("Defaulting to do_lower_case=%s for Sentencepiece tokenizer",
-                  do_lower_case)
-     return do_lower_case
-   else:
-     raise ValueError("Must set vocab_file or sp_model_file.")
-
-
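A short sketch of the guessing logic above; the file names are hypothetical placeholders.

```
# An explicitly passed do_lower_case always wins over any guess:
assert get_do_lower_case(False, vocab_file="uncased_L-12/vocab.txt") is False
# With None, the guess keys off "uncased" in the vocab file name:
assert get_do_lower_case(None, vocab_file="uncased_L-12/vocab.txt") is True
assert get_do_lower_case(None, vocab_file="cased_L-12/vocab.txt") is False
# Sentencepiece models default to lower-casing:
assert get_do_lower_case(None, sp_model_file="albert.model") is True
```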
- def _create_model(
-     *,
-     bert_config: Optional[configs.BertConfig] = None,
-     encoder_config: Optional[encoders.EncoderConfig] = None,
-     with_mlm: bool,
- ) -> Tuple[tf_keras.Model, tf_keras.Model]:
-   """Creates the model to export and the model to restore the checkpoint.
-
-   Args:
-     bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
-       Exactly one of encoder_config and bert_config must be set.
-     encoder_config: An `EncoderConfig` to create an encoder of the configured
-       type (`BertEncoder` or other).
-     with_mlm: A bool to control the second component of the result. If True,
-       will create a `BertPretrainerV2` object; otherwise, will create a
-       `BertEncoder` object.
-
-   Returns:
-     A Tuple of (1) a Keras model that will be exported, (2) a
-     `BertPretrainerV2` object or `BertEncoder` object depending on the value
-     of `with_mlm` argument, which contains the first model and will be used
-     for restoring weights from the checkpoint.
-   """
-   if (bert_config is not None) == (encoder_config is not None):
-     raise ValueError("Exactly one of `bert_config` and `encoder_config` "
-                      "can be specified, but got %s and %s" %
-                      (bert_config, encoder_config))
-
-   if bert_config is not None:
-     encoder = get_bert_encoder(bert_config)
-   else:
-     encoder = encoders.build_encoder(encoder_config)
-
-   # Convert from list of named inputs to dict of inputs keyed by name.
-   # Only the latter accepts a dict of inputs after restoring from SavedModel.
-   if isinstance(encoder.inputs, (list, tuple)):
-     encoder_inputs_dict = {x.name: x for x in encoder.inputs}
-   else:
-     # encoder.inputs by default is dict for BertEncoderV2.
-     encoder_inputs_dict = encoder.inputs
-   encoder_output_dict = encoder(encoder_inputs_dict)
-   # For interchangeability with other text representations,
-   # add "default" as an alias for BERT's whole-input representations.
-   encoder_output_dict["default"] = encoder_output_dict["pooled_output"]
-   core_model = tf_keras.Model(
-       inputs=encoder_inputs_dict, outputs=encoder_output_dict)
-
-   if with_mlm:
-     if bert_config is not None:
-       hidden_act = bert_config.hidden_act
-     else:
-       assert encoder_config is not None
-       hidden_act = encoder_config.get().hidden_activation
-
-     pretrainer = models.BertPretrainerV2(
-         encoder_network=encoder,
-         mlm_activation=tf_utils.get_activation(hidden_act))
-
-     if isinstance(pretrainer.inputs, dict):
-       pretrainer_inputs_dict = pretrainer.inputs
-     else:
-       pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
-     pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
-     mlm_model = tf_keras.Model(
-         inputs=pretrainer_inputs_dict, outputs=pretrainer_output_dict)
-     # Set `_auto_track_sub_layers` to False, so that the additional weights
-     # from `mlm` sub-object will not be included in the core model.
-     # TODO(b/169210253): Use a public API when available.
-     core_model._auto_track_sub_layers = False  # pylint: disable=protected-access
-     core_model.mlm = mlm_model
-     return core_model, pretrainer
-   else:
-     return core_model, encoder
-
-
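A hedged sketch of how `_create_model` is typically driven, reusing the illustrative `bert_config` from the earlier sketch.

```
# Sketch: exactly one of bert_config / encoder_config may be set.
core_model, pretrainer = _create_model(
    bert_config=bert_config, encoder_config=None, with_mlm=True)
# core_model is the exportable Keras model; its `.mlm` attribute holds the
# MLM-head model without letting those weights leak into core_model.
mlm_model = core_model.mlm
# pretrainer is the restore target whose variables match TF-NLP checkpoints.
encoder = pretrainer.encoder_network
```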
- def export_model(export_path: Text,
-                  *,
-                  bert_config: Optional[configs.BertConfig] = None,
-                  encoder_config: Optional[encoders.EncoderConfig] = None,
-                  model_checkpoint_path: Text,
-                  with_mlm: bool,
-                  copy_pooler_dense_to_encoder: bool = False,
-                  vocab_file: Optional[Text] = None,
-                  sp_model_file: Optional[Text] = None,
-                  do_lower_case: Optional[bool] = None) -> None:
-   """Exports an Encoder as SavedModel after restoring pre-trained weights.
-
-   The exported SavedModel implements a superset of the Encoder API for
-   Text embeddings with Transformer Encoders described at
-   https://www.tensorflow.org/hub/common_saved_model_apis/text.
-
-   In particular, the exported SavedModel can be used in the following way:
-
-   ```
-   # Calls default interface (encoder only).
-
-   encoder = hub.load(...)
-   encoder_inputs = dict(
-       input_word_ids=...,  # Shape [batch, seq_length], dtype=int32
-       input_mask=...,      # Shape [batch, seq_length], dtype=int32
-       input_type_ids=...,  # Shape [batch, seq_length], dtype=int32
-   )
-   encoder_outputs = encoder(encoder_inputs)
-   assert encoder_outputs.keys() == {
-       "pooled_output",    # Shape [batch_size, width], dtype=float32
-       "default",          # Alias for "pooled_output" (aligns with other models).
-       "sequence_output",  # Shape [batch_size, seq_length, width], dtype=float32
-       "encoder_outputs",  # List of Tensors with outputs of all transformer layers.
-   }
-   ```
-
-   If `with_mlm` is True, the exported SavedModel can also be called in the
-   following way:
-
-   ```
-   # Calls expanded interface that includes logits of the Masked Language Model.
-   mlm_inputs = dict(
-       input_word_ids=...,       # Shape [batch, seq_length], dtype=int32
-       input_mask=...,           # Shape [batch, seq_length], dtype=int32
-       input_type_ids=...,       # Shape [batch, seq_length], dtype=int32
-       masked_lm_positions=...,  # Shape [batch, num_predictions], dtype=int32
-   )
-   mlm_outputs = encoder.mlm(mlm_inputs)
-   assert mlm_outputs.keys() == {
-       "pooled_output",    # Shape [batch, width], dtype=float32
-       "sequence_output",  # Shape [batch, seq_length, width], dtype=float32
-       "encoder_outputs",  # List of Tensors with outputs of all transformer layers.
-       "mlm_logits"        # Shape [batch, num_predictions, vocab_size], dtype=float32
-   }
-   ```
-
-   Args:
-     export_path: The SavedModel output directory.
-     bert_config: An optional `configs.BertConfig` object. Note: exactly one of
-       `bert_config` and following `encoder_config` must be specified.
-     encoder_config: An optional `encoders.EncoderConfig` object.
-     model_checkpoint_path: The path to the checkpoint.
-     with_mlm: Whether to export the additional mlm sub-object.
-     copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer used
-       in the next sentence prediction task to the encoder.
-     vocab_file: The path to the wordpiece vocab file, or None.
-     sp_model_file: The path to the sentencepiece model file, or None. Exactly
-       one of vocab_file and sp_model_file must be set.
-     do_lower_case: Whether to lower-case text before tokenization.
-   """
-   if with_mlm:
-     core_model, pretrainer = _create_model(
-         bert_config=bert_config,
-         encoder_config=encoder_config,
-         with_mlm=with_mlm)
-     encoder = pretrainer.encoder_network
-     # It supports both the new pretrainer checkpoint produced by TF-NLP and
-     # the checkpoint converted from TF1 (original BERT, SmallBERTs).
-     checkpoint_items = pretrainer.checkpoint_items
-     checkpoint = tf.train.Checkpoint(**checkpoint_items)
-   else:
-     core_model, encoder = _create_model(
-         bert_config=bert_config,
-         encoder_config=encoder_config,
-         with_mlm=with_mlm)
-     checkpoint = tf.train.Checkpoint(
-         model=encoder,  # Legacy checkpoints.
-         encoder=encoder)
-   checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
-
-   if copy_pooler_dense_to_encoder:
-     logging.info("Copy pooler's dense layer to the encoder.")
-     pooler_checkpoint = tf.train.Checkpoint(
-         **{"next_sentence.pooler_dense": encoder.pooler_layer})
-     pooler_checkpoint.restore(
-         model_checkpoint_path).assert_existing_objects_matched()
-
-   # Before SavedModels for preprocessing appeared in Oct 2020, the encoders
-   # provided this information to let users do preprocessing themselves.
-   # We keep doing that for now. It helps users to upgrade incrementally.
-   # Moreover, it offers an escape hatch for advanced users who want the
-   # full vocab, not the high-level operations from the preprocessing model.
-   if vocab_file:
-     core_model.vocab_file = tf.saved_model.Asset(vocab_file)
-     if do_lower_case is None:
-       raise ValueError("Must pass do_lower_case if passing vocab_file.")
-     core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
-   elif sp_model_file:
-     # This was used by ALBERT, with implied values of do_lower_case=True
-     # and strip_diacritics=True.
-     core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
-   else:
-     raise ValueError("Must set vocab_file or sp_model_file.")
-   core_model.save(export_path, include_optimizer=False, save_format="tf")
-
-
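An end-to-end invocation might look like the following sketch; all paths are hypothetical placeholders, and `bert_config` is the illustrative config from the first sketch.

```
# Sketch: export an encoder with an MLM head; paths are placeholders.
export_model(
    "/tmp/exported_bert_encoder",
    bert_config=bert_config,  # or encoder_config=..., never both
    model_checkpoint_path="/tmp/pretrained/bert_model.ckpt",
    with_mlm=True,
    vocab_file="/tmp/pretrained/vocab.txt",
    do_lower_case=True)
```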
- class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
-   """Wraps a BertPackInputs layer for export to SavedModel.
-
-   The wrapper object is suitable for use with `tf.saved_model.save()` and
-   `.load()`. The wrapper object is callable with inputs and outputs like the
-   BertPackInputs layer, but differs from saving an unwrapped Keras object:
-
-     - The inputs can be a list of 1 or 2 RaggedTensors of dtype int32 and
-       ragged rank 1 or 2. (In Keras, saving to a tf.function in a SavedModel
-       would fix the number of RaggedTensors and their ragged rank.)
-     - The call accepts an optional keyword argument `seq_length=` to override
-       the layer's .seq_length hyperparameter. (In Keras, a hyperparameter
-       could not be changed after saving to a tf.function in a SavedModel.)
-   """
-
-   def __init__(self, bert_pack_inputs: layers.BertPackInputs):
-     super().__init__()
-
-     # Preserve the layer's configured seq_length as a default but make it
-     # overridable. Having this dynamically determined default argument
-     # requires self.__call__ to be defined in this indirect way.
-     default_seq_length = bert_pack_inputs.seq_length
-
-     @tf.function(autograph=False)
-     def call(inputs, seq_length=default_seq_length):
-       return layers.BertPackInputs.bert_pack_inputs(
-           inputs,
-           seq_length=seq_length,
-           start_of_sequence_id=bert_pack_inputs.start_of_sequence_id,
-           end_of_segment_id=bert_pack_inputs.end_of_segment_id,
-           padding_id=bert_pack_inputs.padding_id)
-
-     self.__call__ = call
-
-     for ragged_rank in range(1, 3):
-       for num_segments in range(1, 3):
-         _ = self.__call__.get_concrete_function(
-             [tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
-              for _ in range(num_segments)],
-             seq_length=tf.TensorSpec([], tf.int32))
-
-
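A sketch of the save/load round trip the wrapper enables. The special-token ids, token ids, and paths are assumptions for illustration; note the `seq_length` override at call time, which is the point of the wrapper.

```
import tensorflow as tf

# Sketch: a BertPackInputs layer with assumed special-token ids.
pack = layers.BertPackInputs(
    seq_length=128, start_of_sequence_id=101, end_of_segment_id=102,
    padding_id=0)
wrapper = BertPackInputsSavedModelWrapper(pack)
tf.saved_model.save(wrapper, "/tmp/bert_pack_inputs")  # hypothetical path
restored = tf.saved_model.load("/tmp/bert_pack_inputs")

# One segment of ragged-rank-1 int32 tokens; seq_length overrides the 128.
tokens = tf.ragged.constant([[2023, 2003], [2045]], dtype=tf.int32)
packed = restored([tokens], seq_length=tf.constant(32, tf.int32))
# packed: dict with "input_word_ids", "input_mask", "input_type_ids",
# each a dense int32 Tensor of shape [2, 32].
```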
- def create_preprocessing(*,
-                          vocab_file: Optional[str] = None,
-                          sp_model_file: Optional[str] = None,
-                          do_lower_case: bool,
-                          tokenize_with_offsets: bool,
-                          default_seq_length: int) -> tf_keras.Model:
-   """Returns a preprocessing Model for given tokenization parameters.
-
-   This function builds a Keras Model with attached subobjects suitable for
-   saving to a SavedModel. The resulting SavedModel implements the Preprocessor
-   API for Text embeddings with Transformer Encoders described at
-   https://www.tensorflow.org/hub/common_saved_model_apis/text.
-
-   Args:
-     vocab_file: The path to the wordpiece vocab file, or None.
-     sp_model_file: The path to the sentencepiece model file, or None. Exactly
-       one of vocab_file and sp_model_file must be set. This determines the
-       type of tokenizer that is used.
-     do_lower_case: Whether to do lower case.
-     tokenize_with_offsets: Whether to include the .tokenize_with_offsets
-       subobject.
-     default_seq_length: The sequence length of preprocessing results from root
-       callable. This is also the default sequence length for the
-       bert_pack_inputs subobject.
-
-   Returns:
-     A tf_keras.Model object with several attached subobjects, suitable for
-     saving as a preprocessing SavedModel.
-   """
-   # Select tokenizer.
-   if bool(vocab_file) == bool(sp_model_file):
-     raise ValueError("Must set exactly one of vocab_file, sp_model_file")
-   if vocab_file:
-     tokenize = layers.BertTokenizer(
-         vocab_file=vocab_file,
-         lower_case=do_lower_case,
-         tokenize_with_offsets=tokenize_with_offsets)
-   else:
-     tokenize = layers.SentencepieceTokenizer(
-         model_file_path=sp_model_file,
-         lower_case=do_lower_case,
-         strip_diacritics=True,  # Strip diacritics to follow ALBERT model.
-         tokenize_with_offsets=tokenize_with_offsets)
-
-   # The root object of the preprocessing model can be called to do
-   # one-shot preprocessing for users with single-sentence inputs.
-   sentences = tf_keras.layers.Input(shape=(), dtype=tf.string, name="sentences")
-   if tokenize_with_offsets:
-     tokens, start_offsets, limit_offsets = tokenize(sentences)
-   else:
-     tokens = tokenize(sentences)
-   pack = layers.BertPackInputs(
-       seq_length=default_seq_length,
-       special_tokens_dict=tokenize.get_special_tokens_dict())
-   model_inputs = pack(tokens)
-   preprocessing = tf_keras.Model(sentences, model_inputs)
-
-   # Individual steps of preprocessing are made available as named subobjects
-   # to enable more general preprocessing. For saving, they need to be Models
-   # in their own right.
-   preprocessing.tokenize = tf_keras.Model(sentences, tokens)
-   # Provide an equivalent to tokenize.get_special_tokens_dict().
-   preprocessing.tokenize.get_special_tokens_dict = tf.train.Checkpoint()
-   preprocessing.tokenize.get_special_tokens_dict.__call__ = tf.function(
-       lambda: tokenize.get_special_tokens_dict(),  # pylint: disable=unnecessary-lambda
-       input_signature=[])
-   if tokenize_with_offsets:
-     preprocessing.tokenize_with_offsets = tf_keras.Model(
-         sentences, [tokens, start_offsets, limit_offsets])
-     preprocessing.tokenize_with_offsets.get_special_tokens_dict = (
-         preprocessing.tokenize.get_special_tokens_dict)
-   # Conceptually, this should be
-   # preprocessing.bert_pack_inputs = tf_keras.Model(tokens, model_inputs)
-   # but technicalities require us to use a wrapper (see comments there).
-   # In particular, seq_length can be overridden when calling this.
-   preprocessing.bert_pack_inputs = BertPackInputsSavedModelWrapper(pack)
-
-   return preprocessing
-
-
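A minimal sketch of building and calling the preprocessing model, assuming a hypothetical wordpiece vocab file (the tokenizer reads it at construction time, so the path must exist in practice).

```
import tensorflow as tf

# Sketch: wordpiece preprocessing; the vocab path is hypothetical.
preprocessing = create_preprocessing(
    vocab_file="/tmp/pretrained/vocab.txt",
    do_lower_case=True,
    tokenize_with_offsets=False,
    default_seq_length=128)
encoder_inputs = preprocessing(tf.constant(["hello world", "tf hub"]))
# encoder_inputs: dict of int32 Tensors of shape [2, 128].
tokens = preprocessing.tokenize(tf.constant(["hello world"]))  # RaggedTensor
```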
- def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]:
-   """Returns new path with same basename and hash of original path."""
-   if file_path is None:
-     return None
-   olddir, filename = os.path.split(file_path)
-   hasher = hashlib.sha1()
-   hasher.update(olddir.encode("utf-8"))
-   target_dir = os.path.join(tmpdir, hasher.hexdigest())
-   target_file = os.path.join(target_dir, filename)
-   tf.io.gfile.mkdir(target_dir)
-   tf.io.gfile.copy(file_path, target_file)
-   return target_file
-
-
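The sha1 of the original directory keeps equal basenames from colliding, as this sketch illustrates (the source paths are hypothetical, so the copies would only succeed if those files existed).

```
with tempfile.TemporaryDirectory() as tmpdir:
  a = _move_to_tmpdir("/data/model_a/vocab.txt", tmpdir)  # hypothetical
  b = _move_to_tmpdir("/data/model_b/vocab.txt", tmpdir)  # hypothetical
  # Same basename, but different hash-named parent directories:
  assert os.path.basename(a) == os.path.basename(b) == "vocab.txt"
  assert os.path.dirname(a) != os.path.dirname(b)
```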
- def export_preprocessing(export_path: Text,
-                          *,
-                          vocab_file: Optional[Text] = None,
-                          sp_model_file: Optional[Text] = None,
-                          do_lower_case: bool,
-                          tokenize_with_offsets: bool,
-                          default_seq_length: int,
-                          experimental_disable_assert: bool = False) -> None:
-   """Exports preprocessing to a SavedModel for TF Hub."""
-   with tempfile.TemporaryDirectory() as tmpdir:
-     # TODO(b/175369555): Remove experimental_disable_assert and its use.
-     with _maybe_disable_assert(experimental_disable_assert):
-       preprocessing = create_preprocessing(
-           vocab_file=_move_to_tmpdir(vocab_file, tmpdir),
-           sp_model_file=_move_to_tmpdir(sp_model_file, tmpdir),
-           do_lower_case=do_lower_case,
-           tokenize_with_offsets=tokenize_with_offsets,
-           default_seq_length=default_seq_length)
-       preprocessing.save(export_path, include_optimizer=False, save_format="tf")
-     if experimental_disable_assert:
-       _check_no_assert(export_path)
-   # It helps the unit test to prevent stray copies of the vocab file.
-   if tf.io.gfile.exists(tmpdir):
-     raise IOError("Failed to clean up TemporaryDirectory")
-
-
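The corresponding top-level call might look like this sketch; paths are placeholders.

```
# Sketch: export a wordpiece preprocessing SavedModel.
export_preprocessing(
    "/tmp/exported_bert_preprocess",
    vocab_file="/tmp/pretrained/vocab.txt",
    do_lower_case=True,
    tokenize_with_offsets=True,
    default_seq_length=128)
```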
- # TODO(b/175369555): Remove all workarounds for this bug of TensorFlow 2.4
- # when this bug is no longer a concern for publishing new models.
- # TensorFlow 2.4 has a placement issue with Assert ops in tf.functions called
- # from Dataset.map() on a TPU worker. They end up on the TPU coordinator,
- # and invoking them from the TPU worker is either inefficient (when possible)
- # or impossible (notably when using "headless" TPU workers on Cloud that do
- # not have a channel to the coordinator). The bug has been fixed in time for
- # TF 2.5. To work around this, the following code avoids Assert ops in the
- # exported SavedModels. It monkey-patches calls to tf.Assert from inside
- # TensorFlow and replaces them by a no-op while building the exported model.
- # This is fragile, so _check_no_assert() validates the result. The resulting
- # model should be fine to read on future versions of TF, even if this
- # workaround at export time may break eventually. (Failing unit tests will
- # tell.)
-
-
- def _dont_assert(condition, data, summarize=None, name="Assert"):
-   """The no-op version of tf.Assert installed by _maybe_disable_assert."""
-   del condition, data, summarize  # Unused.
-   if tf.executing_eagerly():
-     return
-   with tf.name_scope(name):
-     return tf.no_op(name="dont_assert")
-
-
- @contextlib.contextmanager
- def _maybe_disable_assert(disable_assert):
-   """Scoped monkey patch of control_flow_assert.Assert to a no-op."""
-   if not disable_assert:
-     yield
-     return
-
-   original_assert = control_flow_assert.Assert
-   control_flow_assert.Assert = _dont_assert
-   yield
-   control_flow_assert.Assert = original_assert
-
-
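The patch is scoped: `control_flow_assert.Assert` is swapped only inside the `with` block, as this small sketch shows.

```
with _maybe_disable_assert(True):
  assert control_flow_assert.Assert is _dont_assert  # patched in scope
assert control_flow_assert.Assert is not _dont_assert  # restored on exit

with _maybe_disable_assert(False):
  assert control_flow_assert.Assert is not _dont_assert  # untouched
```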
- def _check_no_assert(saved_model_path):
-   """Raises AssertionError if SavedModel contains Assert ops."""
-   saved_model_filename = os.path.join(saved_model_path, "saved_model.pb")
-   with tf.io.gfile.GFile(saved_model_filename, "rb") as f:
-     saved_model = saved_model_pb2.SavedModel.FromString(f.read())
-
-   assert_nodes = []
-   graph_def = saved_model.meta_graphs[0].graph_def
-   assert_nodes += [
-       "node '{}' in global graph".format(n.name)
-       for n in graph_def.node
-       if n.op == "Assert"
-   ]
-   for fdef in graph_def.library.function:
-     assert_nodes += [
-         "node '{}' in function '{}'".format(n.name, fdef.signature.name)
-         for n in fdef.node_def
-         if n.op == "Assert"
-     ]
-   if assert_nodes:
-     raise AssertionError(
-         "Internal tool error: "
-         "failed to suppress {} Assert ops in SavedModel:\n{}".format(
-             len(assert_nodes), "\n".join(assert_nodes[:10])))
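Putting the last three pieces together, a hedged validation sketch: build and save with Asserts suppressed, then scan the serialized graph. The vocab and export paths are hypothetical.

```
with _maybe_disable_assert(True):
  preprocessing = create_preprocessing(
      vocab_file="/tmp/pretrained/vocab.txt",
      do_lower_case=True,
      tokenize_with_offsets=False,
      default_seq_length=128)
  preprocessing.save("/tmp/no_assert_export", include_optimizer=False,
                     save_format="tf")
_check_no_assert("/tmp/no_assert_export")  # raises AssertionError on failure
```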