Commit 26d020f by Dr. Jorge Abreu Vicente
Parent(s): 561cc87
Create convert_biomegatron_checkpoint.py
convert_biomegatron_checkpoint.py
ADDED
@@ -0,0 +1,198 @@
import argparse
import json
import os
import re
import zipfile

import torch

####################################################################################################
# This file is a modification of the original
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py

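# Debugging helper: recursive_print(None, state_dict) pretty-prints the nested structure
# of a checkpoint dictionary (keys and tensor shapes).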
def recursive_print(name, val, spaces=0):
    # Format the message.
    if name is None:
        msg = None
    else:
        fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
        msg = fmt.format(name)

    # Print and recurse (if needed).
    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for k in val.keys():
            recursive_print(k, val[k], spaces + 2)
    elif isinstance(val, torch.Tensor):
        print(msg, ":", val.size())
    else:
        print(msg, ":", val)


def convert_megatron_checkpoint(input_state_dict, head_model=True):
    # The converted output model.
    output_state_dict = {}

    # The model.
    model = input_state_dict["model"]
    # The language model.
    lm = model["language_model"]
    # The embeddings.
    embeddings = lm["embedding"]

    # The word embeddings.
    word_embeddings = embeddings["word_embeddings"]["weight"]
    # Store the word embeddings.
    output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings

    # The position embeddings.
    pos_embeddings = embeddings["position_embeddings"]["weight"]
    # Trained for 512 x 1024.
    assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
    # Store the position embeddings.
    output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings

    # The token-type embeddings.
    tokentype_embeddings = embeddings["tokentype_embeddings"]["weight"]
    # Store the token-type embeddings.
    output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings

    # The transformer.
    transformer = lm["transformer"]

    # The regex to extract layer names.
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # The simple map of names for "automated" rules.
    megatron_to_transformers = {
        "attention.dense": ".attention.output.dense.",
        "mlp.dense_h_to_4h": ".intermediate.dense.",
        "mlp.dense_4h_to_h": ".output.dense.",
    }

    # Keep track of the attention query/key/value tensor.
    attention_qkv_weight = None

    # Extract the layers.
    for key, val in transformer.items():
        # Match the name.
        m = layer_re.match(key)

        # Stop if that's not a layer.
        if m is None:
            break

        # The index of the layer.
        layer_idx = int(m.group(1))
        # The name of the operation.
        op_name = m.group(2)
        # Is it a weight or a bias?
        weight_or_bias = m.group(3)

        # The name of the layer.
        layer_name = f"bert.encoder.layer.{layer_idx}"

        # For layernorm(s), simply store the layer norm.
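        # NOTE: the target names ("attention.ln", "ln" and, later, "bert.encoder.ln") follow
        # the MegatronBertModel implementation in transformers rather than the standard
        # BertModel naming, which is why the config below sets model_type to "megatron-bert".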
        if op_name.endswith("layernorm"):

            ln_name = "attention.ln" if op_name.startswith("input") else "ln"
            output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

        # Capture the QKV weight; it is split together with the bias below.
        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":

            # Make sure the QKV pointer is nil.
            assert attention_qkv_weight is None, ""

            # Store the tensor, as we also need the bias before splitting Q, K and V.
            attention_qkv_weight = val

        # Split the stored QKV weight and the bias into Q, K and V.
        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":

            # Make sure we read the weight tensor.
            assert attention_qkv_weight is not None, ""

            # Split the QKV matrix into Q, K and V. Megatron stores Q,K,V interleaved.
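            # NOTE: the slices below assume the fused QKV tensor is laid out as three
            # contiguous blocks of hidden_size (= 1024) rows; a different hidden size would
            # require changing the hard-coded offsets.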
            q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
            k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
            v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]

            # Split the bias.
            q_bias = val[0 * 1024 : 1 * 1024]
            k_bias = val[1 * 1024 : 2 * 1024]
            v_bias = val[2 * 1024 : 3 * 1024]

            # Store.
            output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
            output_state_dict[f"{layer_name}.attention.self.query.bias"] = q_bias
            output_state_dict[f"{layer_name}.attention.self.key.weight"] = k
            output_state_dict[f"{layer_name}.attention.self.key.bias"] = k_bias
            output_state_dict[f"{layer_name}.attention.self.value.weight"] = v
            output_state_dict[f"{layer_name}.attention.self.value.bias"] = v_bias

            # Clear the stored tensor.
            attention_qkv_weight = None

        # Copy weights and biases as is.
        elif weight_or_bias in ["weight", "bias"]:

            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + weight_or_bias] = val

    # The final layernorm.
    output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
    output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]

    # The config.
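    # NOTE: the hyperparameters below are hard-coded for a BERT-large-sized Megatron model
    # (24 layers, hidden size 1024, 16 heads, 512 positions), presumably the 345M
    # BioMegatron checkpoints this script targets; other model sizes would need these values
    # (and the hard-coded 1024/512 sizes above) adjusted.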
    output_config = {
        "vocab_size": word_embeddings.size(0),
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "hidden_act": "gelu_new",
        "intermediate_size": 4096,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.2,
        "layer_norm_eps": 1e-12,
        "position_embedding_type": "absolute",
        "use_cache": False,
        "model_type": "megatron-bert",
    }

    if head_model:
        # The pooler.
        pooler = lm["pooler"]

        # Store the matrix and the bias.
        output_state_dict["bert.pooler.dense.weight"] = pooler["dense.weight"]
        output_state_dict["bert.pooler.dense.bias"] = pooler["dense.bias"]

        # The LM head from Megatron (for RACE).
        lm_head = model["lm_head"]

        # The transform matrix.
        output_state_dict["cls.predictions.transform.dense.weight"] = lm_head["dense.weight"]
        output_state_dict["cls.predictions.transform.dense.bias"] = lm_head["dense.bias"]

        # The transform LN.
        output_state_dict["cls.predictions.transform.LayerNorm.weight"] = lm_head["layernorm.weight"]
        output_state_dict["cls.predictions.transform.LayerNorm.bias"] = lm_head["layernorm.bias"]

        # For the decoder, we replicate the weights.
        output_state_dict["cls.predictions.decoder.weight"] = word_embeddings
        output_state_dict["cls.predictions.bias"] = lm_head["bias"]

        # The classifier from Megatron (for MNLI).
        binary_head = model["binary_head"]

        # Store the classifier.
        output_state_dict["cls.seq_relationship.weight"] = binary_head["weight"]
        output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]

    # It should be done!
    return output_state_dict, output_config
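As committed, the file defines only recursive_print and convert_megatron_checkpoint; there is no command-line entry point, even though argparse, json, os and zipfile are imported. Below is a minimal driver sketch, in the spirit of the upstream transformers script this file adapts, showing how the conversion might be run end to end. The checkpoint path, the archive member name and the output file names are assumptions rather than part of the commit.

if __name__ == "__main__":
    # Minimal driver sketch (not part of the committed file): load a Megatron/BioMegatron
    # checkpoint, convert it, and write a transformers-style config.json and pytorch_model.bin.
    parser = argparse.ArgumentParser()
    parser.add_argument("path_to_checkpoint", type=str, help="Path to the checkpoint (.zip or .pt)")
    parser.add_argument("--print-checkpoint-structure", action="store_true")
    args = parser.parse_args()

    # Megatron checkpoints are often shipped as a zip archive; the member name below is the
    # one used by the upstream converter and may need adjusting for a given release.
    if args.path_to_checkpoint.endswith(".zip"):
        with zipfile.ZipFile(args.path_to_checkpoint, "r") as archive:
            with archive.open("release/mp_rank_00/model_optim_rng.pt") as fp:
                input_state_dict = torch.load(fp, map_location="cpu")
    else:
        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")

    output_state_dict, output_config = convert_megatron_checkpoint(input_state_dict)

    # Optionally inspect the converted structure with the helper above.
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)

    # Write the converted files next to the input checkpoint.
    basename = os.path.dirname(args.path_to_checkpoint)
    with open(os.path.join(basename, "config.json"), "w") as f:
        json.dump(output_config, f)
    torch.save(output_state_dict, os.path.join(basename, "pytorch_model.bin"))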