Pradeep Kumar committed on
Commit
adc4eb5
·
verified ·
1 Parent(s): f724412

Delete tf1_bert_checkpoint_converter_lib.py

Browse files
Files changed (1) hide show
  1. tf1_bert_checkpoint_converter_lib.py +0 -201
tf1_bert_checkpoint_converter_lib.py DELETED
@@ -1,201 +0,0 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
16
-
17
- import numpy as np
18
- import tensorflow.compat.v1 as tf # TF 1.x
19
-
20
# Old-name => new-name substitutions. Each entry is a pair of
# (substring to find in the original variable name, text to put in its place).
# `_bert_name_replacement` applies the pairs one after another in the order
# listed here, so an earlier substitution can influence what a later pair sees.
BERT_NAME_REPLACEMENTS = (
    # Top-level scope.
    ("bert", "bert_model"),
    # Embedding scopes.
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    (
        "embeddings/token_type_embeddings",
        "embedding_postprocessor/type_embeddings",
    ),
    (
        "embeddings/position_embeddings",
        "embedding_postprocessor/position_embeddings",
    ),
    ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
    # Attention scopes.
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    # Intermediate / output scopes.
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    # Pooler scope.
    ("pooler/dense", "pooler_transform"),
)
38
-
39
# Second set of (find, replace) substring pairs, in the same format as
# BERT_NAME_REPLACEMENTS. Order matters: `_bert_name_replacement` applies the
# pairs sequentially, e.g. "cls/predictions" is rewritten before the longer
# "cls/predictions/output_bias" pair is tried against the already-updated name.
BERT_V2_NAME_REPLACEMENTS = (
    # Top-level scopes.
    ("bert/", ""),
    ("encoder", "transformer"),
    # Embedding scopes.
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
    ("embeddings/position_embeddings", "position_embedding/embeddings"),
    ("embeddings/LayerNorm", "embeddings/layer_norm"),
    # Attention scopes.
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention/attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    # Intermediate / output scopes.
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    # Pooler scope.
    ("pooler/dense", "pooler_transform"),
    # cls/ scopes.
    ("cls/predictions", "bert/cls/predictions"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
    ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
    (
        "cls/seq_relationship/output_weights",
        "predictions/transform/logits/kernel",
    ),
)
59
-
60
# Permutation tables hold (variable-name substring, axis order) pairs in the
# form `_get_permutation` expects. This one is empty: nothing is transposed.
BERT_PERMUTATIONS = ()

# A single entry: an axis order of (1, 0) for the matching variable.
BERT_V2_PERMUTATIONS = (
    ("cls/seq_relationship/output_weights", (1, 0)),
)
63
-
64
-
65
- def _bert_name_replacement(var_name, name_replacements):
66
- """Gets the variable name replacement."""
67
- for src_pattern, tgt_pattern in name_replacements:
68
- if src_pattern in var_name:
69
- old_var_name = var_name
70
- var_name = var_name.replace(src_pattern, tgt_pattern)
71
- tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
72
- return var_name
73
-
74
-
75
- def _has_exclude_patterns(name, exclude_patterns):
76
- """Checks if a string contains substrings that match patterns to exclude."""
77
- for p in exclude_patterns:
78
- if p in name:
79
- return True
80
- return False
81
-
82
-
83
- def _get_permutation(name, permutations):
84
- """Checks whether a variable requires transposition by pattern matching."""
85
- for src_pattern, permutation in permutations:
86
- if src_pattern in name:
87
- tf.logging.info("Permuted: %s --> %s", name, permutation)
88
- return permutation
89
-
90
- return None
91
-
92
-
93
- def _get_new_shape(name, shape, num_heads):
94
- """Checks whether a variable requires reshape by pattern matching."""
95
- if "self_attention/attention_output/kernel" in name:
96
- return tuple([num_heads, shape[0] // num_heads, shape[1]])
97
- if "self_attention/attention_output/bias" in name:
98
- return shape
99
-
100
- patterns = [
101
- "self_attention/query", "self_attention/value", "self_attention/key"
102
- ]
103
- for pattern in patterns:
104
- if pattern in name:
105
- if "kernel" in name:
106
- return tuple([shape[0], num_heads, shape[1] // num_heads])
107
- if "bias" in name:
108
- return tuple([num_heads, shape[0] // num_heads])
109
- return None
110
-
111
-
112
def create_v2_checkpoint(model,
                         src_checkpoint,
                         output_path,
                         checkpoint_model_name="model"):
    """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint.

    Args:
      model: The object to restore into and then save. It must provide
        `load_weights`; if it has a `checkpoint_items` attribute, those items
        are saved alongside it.
      src_checkpoint: Path of the source checkpoint passed to
        `model.load_weights`.
      output_path: Path passed to `Checkpoint.save` for the new checkpoint.
      checkpoint_model_name: Key under which `model` itself is stored in the
        saved checkpoint; overrides a same-named key in `checkpoint_items`.
    """
    # Uses streaming-restore in eager mode to read V1 name-based checkpoints.
    model.load_weights(src_checkpoint).assert_existing_objects_matched()

    # Work on a copy so that registering the model below does not write an
    # extra key into the dict owned by `model`.
    checkpoint_items = dict(getattr(model, "checkpoint_items", {}))
    checkpoint_items[checkpoint_model_name] = model

    checkpoint = tf.train.Checkpoint(**checkpoint_items)
    checkpoint.save(output_path)
127
-
128
-
129
def convert(checkpoint_from_path,
            checkpoint_to_path,
            num_heads,
            name_replacements,
            permutations,
            exclude_patterns=None):
    """Migrates the names of variables within a checkpoint.

    Reads every variable from the source checkpoint, renames it, optionally
    reshapes and transposes it, and writes the result to a new checkpoint.

    Args:
      checkpoint_from_path: Path to source checkpoint to be read in.
      checkpoint_to_path: Path to checkpoint to be written out.
      num_heads: The number of heads of the model. Reshaping is only attempted
        when this is greater than zero.
      name_replacements: A list of tuples of the form (match_str, replace_str)
        describing variable names to adjust.
      permutations: A list of tuples of the form (match_str, permutation)
        describing permutations to apply to given variables. Note that match_str
        should match the original variable name, not the replaced one.
      exclude_patterns: A list of string patterns to exclude variables from
        checkpoint conversion.

    Returns:
      None. The converted variables are written to `checkpoint_to_path`, and
      the old-name to new-name mapping is only reported through the log.
    """
    with tf.Graph().as_default():
        tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
        reader = tf.train.NewCheckpointReader(checkpoint_from_path)
        name_shape_map = reader.get_variable_to_shape_map()
        new_variable_map = {}
        conversion_map = {}
        for var_name in name_shape_map:
            if exclude_patterns and _has_exclude_patterns(
                    var_name, exclude_patterns):
                continue
            # Get the original tensor data.
            tensor = reader.get_tensor(var_name)

            # Look up the new variable name, if any.
            new_var_name = _bert_name_replacement(var_name, name_replacements)

            # See if we need to reshape the underlying tensor. The reshape
            # rules are matched against the *new* name.
            new_shape = None
            if num_heads > 0:
                new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
            if new_shape:
                tf.logging.info("Variable %s has a shape change from %s to %s",
                                var_name, tensor.shape, new_shape)
                tensor = np.reshape(tensor, new_shape)

            # See if we need to permute the underlying tensor. Permutations
            # are matched against the *original* name.
            permutation = _get_permutation(var_name, permutations)
            if permutation:
                tensor = np.transpose(tensor, permutation)

            # Create a new variable with the possibly-reshaped or transposed
            # tensor.
            var = tf.Variable(tensor, name=var_name)

            # Save the variable into the new variable map; the Saver below
            # writes it under the new name.
            new_variable_map[new_var_name] = var

            # Keep a list of converted variables for sanity checking.
            if new_var_name != var_name:
                conversion_map[var_name] = new_var_name

        saver = tf.train.Saver(new_variable_map)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
            saver.save(sess, checkpoint_to_path, write_meta_graph=False)

    tf.logging.info("Summary:")
    # Note: this count covers every variable written, including those whose
    # name did not change.
    tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
    tf.logging.info(" Converted: %s", conversion_map)