aapot commited on
Commit
69bb3e4
·
1 Parent(s): 459f258

Add 500k step pytorch model

Browse files
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 1536,
7
+ "d_kv": 64,
8
+ "d_model": 384,
9
+ "decoder_start_token_id": 0,
10
+ "dropout_rate": 0.0,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "initializer_factor": 1.0,
14
+ "is_encoder_decoder": true,
15
+ "layer_norm_epsilon": 1e-06,
16
+ "model_type": "t5",
17
+ "n_positions": 512,
18
+ "num_decoder_layers": 8,
19
+ "num_heads": 8,
20
+ "num_layers": 8,
21
+ "output_past": true,
22
+ "pad_token_id": 0,
23
+ "relative_attention_max_distance": 128,
24
+ "relative_attention_num_buckets": 32,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.17.0",
28
+ "use_cache": true,
29
+ "vocab_size": 32128
30
+ }
convert_t5x_checkpoint_to_flax.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://gist.github.com/stefan-it/30e4998ef159f33696e377a46f699d9f
2
+
3
+ import argparse
4
+
5
+ from t5x import checkpoints
6
+ from transformers import T5Config, FlaxT5ForConditionalGeneration
7
+
8
+
9
+ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
10
+ config = T5Config.from_pretrained(config_name)
11
+ flax_model = FlaxT5ForConditionalGeneration(config=config)
12
+ t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
13
+
14
+ split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
15
+
16
+ # Encoder
17
+ for layer_index in range(config.num_layers):
18
+ layer_name = f"layers_{str(layer_index)}"
19
+
20
+ # Self-Attention
21
+ t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
22
+ t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
23
+ t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
24
+ t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
25
+
26
+ ## Layer Normalization
27
+ t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
28
+
29
+ if split_mlp_wi:
30
+ t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
31
+ t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
32
+ else:
33
+ t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
34
+
35
+ t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
36
+
37
+ ## Layer Normalization
38
+ t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
39
+
40
+ # Assigning
41
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
42
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
43
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
44
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
45
+
46
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
47
+
48
+ if split_mlp_wi:
49
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
50
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
51
+ else:
52
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
53
+
54
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
55
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
56
+
57
+ # Only for layer 0:
58
+ t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
59
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding
60
+
61
+ # Assigning
62
+ t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
63
+ flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
64
+
65
+ # Decoder
66
+ for layer_index in range(config.num_layers):
67
+ layer_name = f"layers_{str(layer_index)}"
68
+
69
+ # Self-Attention
70
+ t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
71
+ t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
72
+ t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
73
+ t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
74
+
75
+ ## Layer Normalization
76
+ t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
77
+
78
+ # Encoder-Decoder-Attention
79
+ t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
80
+ t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
81
+ t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
82
+ t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
83
+
84
+ ## Layer Normalization
85
+ t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
86
+
87
+ # MLP
88
+ if split_mlp_wi:
89
+ t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
90
+ t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
91
+ else:
92
+ t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
93
+
94
+ t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
95
+
96
+ ## Layer Normalization
97
+ tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
98
+
99
+ # Assigning
100
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
101
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
102
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
103
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
104
+
105
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
106
+
107
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
108
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
109
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
110
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
111
+
112
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
113
+
114
+ if split_mlp_wi:
115
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
116
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
117
+ else:
118
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
119
+
120
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
121
+
122
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
123
+
124
+ # Decoder Normalization
125
+ tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
126
+ flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
127
+
128
+ # Only for layer 0:
129
+ t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
130
+ flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding
131
+
132
+ # Token Embeddings
133
+ tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
134
+ flax_model.params["shared"]["embedding"] = tx5_token_embeddings
135
+
136
+ # LM Head
137
+ flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
138
+
139
+ flax_model.save_pretrained(flax_dump_folder_path)
140
+ print("T5X Model was sucessfully converted!")
141
+
142
+
143
+ if __name__ == "__main__":
144
+ parser = argparse.ArgumentParser()
145
+ # Required parameters
146
+ parser.add_argument(
147
+ "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
148
+ )
149
+ parser.add_argument(
150
+ "--config_name", default=None, type=str, required=True, help="Config name of T5 model."
151
+ )
152
+ parser.add_argument(
153
+ "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
154
+ )
155
+ args = parser.parse_args()
156
+ convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
157
+
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6569d61dd4c170e1cb81f0196ceb6130831c2bd535c170484a4608b7b6d31af
3
+ size 287515687
flax_model_to_pytorch.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, FlaxAutoModelForSeq2SeqLM, AutoTokenizer
2
+ import torch
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ def to_f32(t):
8
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
9
+
10
+ jax.config.update('jax_platform_name', 'cpu')
11
+ MODEL_PATH = "./"
12
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
13
+ model.params = to_f32(model.params)
14
+ model.save_pretrained(MODEL_PATH)
15
+
16
+ pt_model = AutoModelForSeq2SeqLM.from_pretrained(
17
+ MODEL_PATH, from_flax=True).to('cpu')
18
+
19
+ input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
20
+ input_ids_pt = torch.tensor(input_ids)
21
+
22
+ logits_pt = pt_model(input_ids=input_ids_pt, decoder_input_ids=input_ids_pt).logits
23
+ print(logits_pt)
24
+ logits_fx = model(input_ids=input_ids, decoder_input_ids=input_ids).logits
25
+ print(logits_fx)
26
+
27
+ pt_model.save_pretrained(MODEL_PATH)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f5bc318200d102d8d3ac2bea0df1f726baa1cbc87cee7b8db532f125e81fdee
3
+ size 287594521