Pradeep Kumar committed on
Commit
1b88228
·
verified ·
1 Parent(s): f120585

Delete tf2_bert_encoder_checkpoint_converter.py

Browse files
tf2_bert_encoder_checkpoint_converter.py DELETED
@@ -1,160 +0,0 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.
16
-
17
- The conversion will yield an object-oriented checkpoint that can be used
18
- to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
19
- FLAG below).
20
- """
21
-
22
- import os
23
-
24
- from absl import app
25
- from absl import flags
26
-
27
- import tensorflow as tf, tf_keras
28
- from official.legacy.bert import configs
29
- from official.modeling import tf_utils
30
- from official.nlp.modeling import models
31
- from official.nlp.modeling import networks
32
- from official.nlp.tools import tf1_bert_checkpoint_converter_lib
33
-
34
- FLAGS = flags.FLAGS
35
-
36
- flags.DEFINE_string("bert_config_file", None,
37
- "Bert configuration file to define core bert layers.")
38
- flags.DEFINE_string(
39
- "checkpoint_to_convert", None,
40
- "Initial checkpoint from a pretrained BERT model core (that is, only the "
41
- "BertModel, with no task heads.)")
42
- flags.DEFINE_string("converted_checkpoint_path", None,
43
- "Name for the created object-based V2 checkpoint.")
44
- flags.DEFINE_string("checkpoint_model_name", "encoder",
45
- "The name of the model when saving the checkpoint, i.e., "
46
- "the checkpoint will be saved using: "
47
- "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
48
- flags.DEFINE_enum(
49
- "converted_model", "encoder", ["encoder", "pretrainer"],
50
- "Whether to convert the checkpoint to a `BertEncoder` model or a "
51
- "`BertPretrainerV2` model (with mlm but without classification heads).")
52
-
53
-
54
- def _create_bert_model(cfg):
55
- """Creates a BERT keras core model from BERT configuration.
56
-
57
- Args:
58
- cfg: A `BertConfig` to create the core model.
59
-
60
- Returns:
61
- A BertEncoder network.
62
- """
63
- bert_encoder = networks.BertEncoder(
64
- vocab_size=cfg.vocab_size,
65
- hidden_size=cfg.hidden_size,
66
- num_layers=cfg.num_hidden_layers,
67
- num_attention_heads=cfg.num_attention_heads,
68
- intermediate_size=cfg.intermediate_size,
69
- activation=tf_utils.get_activation(cfg.hidden_act),
70
- dropout_rate=cfg.hidden_dropout_prob,
71
- attention_dropout_rate=cfg.attention_probs_dropout_prob,
72
- max_sequence_length=cfg.max_position_embeddings,
73
- type_vocab_size=cfg.type_vocab_size,
74
- initializer=tf_keras.initializers.TruncatedNormal(
75
- stddev=cfg.initializer_range),
76
- embedding_width=cfg.embedding_size)
77
-
78
- return bert_encoder
79
-
80
-
81
- def _create_bert_pretrainer_model(cfg):
82
- """Creates a BERT keras core model from BERT configuration.
83
-
84
- Args:
85
- cfg: A `BertConfig` to create the core model.
86
-
87
- Returns:
88
- A BertPretrainerV2 model.
89
- """
90
- bert_encoder = _create_bert_model(cfg)
91
- pretrainer = models.BertPretrainerV2(
92
- encoder_network=bert_encoder,
93
- mlm_activation=tf_utils.get_activation(cfg.hidden_act),
94
- mlm_initializer=tf_keras.initializers.TruncatedNormal(
95
- stddev=cfg.initializer_range))
96
- # Makes sure the pretrainer variables are created.
97
- _ = pretrainer(pretrainer.inputs)
98
- return pretrainer
99
-
100
-
101
- def convert_checkpoint(bert_config,
102
- output_path,
103
- v1_checkpoint,
104
- checkpoint_model_name="model",
105
- converted_model="encoder"):
106
- """Converts a V1 checkpoint into an OO V2 checkpoint."""
107
- output_dir, _ = os.path.split(output_path)
108
- tf.io.gfile.makedirs(output_dir)
109
-
110
- # Create a temporary V1 name-converted checkpoint in the output directory.
111
- temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
112
- temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
113
-
114
- tf1_bert_checkpoint_converter_lib.convert(
115
- checkpoint_from_path=v1_checkpoint,
116
- checkpoint_to_path=temporary_checkpoint,
117
- num_heads=bert_config.num_attention_heads,
118
- name_replacements=(
119
- tf1_bert_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS),
120
- permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
121
- exclude_patterns=["adam", "Adam"])
122
-
123
- if converted_model == "encoder":
124
- model = _create_bert_model(bert_config)
125
- elif converted_model == "pretrainer":
126
- model = _create_bert_pretrainer_model(bert_config)
127
- else:
128
- raise ValueError("Unsupported converted_model: %s" % converted_model)
129
-
130
- # Create a V2 checkpoint from the temporary checkpoint.
131
- tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
132
- model, temporary_checkpoint, output_path, checkpoint_model_name)
133
-
134
- # Clean up the temporary checkpoint, if it exists.
135
- try:
136
- tf.io.gfile.rmtree(temporary_checkpoint_dir)
137
- except tf.errors.OpError:
138
- # If it doesn't exist, we don't need to clean it up; continue.
139
- pass
140
-
141
-
142
- def main(argv):
143
- if len(argv) > 1:
144
- raise app.UsageError("Too many command-line arguments.")
145
-
146
- output_path = FLAGS.converted_checkpoint_path
147
- v1_checkpoint = FLAGS.checkpoint_to_convert
148
- checkpoint_model_name = FLAGS.checkpoint_model_name
149
- converted_model = FLAGS.converted_model
150
- bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
151
- convert_checkpoint(
152
- bert_config=bert_config,
153
- output_path=output_path,
154
- v1_checkpoint=v1_checkpoint,
155
- checkpoint_model_name=checkpoint_model_name,
156
- converted_model=converted_model)
157
-
158
-
159
- if __name__ == "__main__":
160
- app.run(main)