aapot committed on
Commit 7460603
1 Parent(s): 5d1d3e2

Add 50k train step model

config.gin ADDED
@@ -0,0 +1,148 @@
+ from __gin__ import dynamic_registration
+ import __main__ as train_script
+ import seqio
+ from t5x import adafactor
+ from t5x.examples.scalable_t5 import network
+ from t5x import gin_utils
+ from t5x import models
+ from t5x import partitioning
+ from t5x import trainer
+ from t5x import utils
+ import tasks
+
+ # Macros:
+ # ==============================================================================
+ BATCH_SIZE = 256
+ DROPOUT_RATE = 0.0
+ LABEL_SMOOTHING = 0.0
+ LOSS_NORMALIZING_FACTOR = None
+ MIXTURE_OR_TASK_MODULE = None
+ MIXTURE_OR_TASK_NAME = 'pretrain_finnish'
+ MODEL = @models.EncoderDecoderModel()
+ MODEL_DIR = '/researchdisk/t5-small-nl16-finnish'
+ OPTIMIZER = @adafactor.Adafactor()
+ RANDOM_SEED = None
+ SHUFFLE_TRAIN_EXAMPLES = True
+ TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
+ TRAIN_STEPS = 500000
+ USE_CACHED_TASKS = False
+ USE_HARDWARE_RNG = False
+ VOCABULARY = @seqio.SentencePieceVocabulary()
+ Z_LOSS = 0.0001
+
+ # Parameters for adafactor.Adafactor:
+ # ==============================================================================
+ adafactor.Adafactor.decay_rate = 0.8
+ adafactor.Adafactor.logical_factor_rules = \
+     @adafactor.standard_logical_factor_rules()
+ adafactor.Adafactor.step_offset = 0
+
+ # Parameters for utils.CheckpointConfig:
+ # ==============================================================================
+ utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
+ utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
+
+ # Parameters for utils.create_learning_rate_scheduler:
+ # ==============================================================================
+ utils.create_learning_rate_scheduler.base_learning_rate = 1.0
+ utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
+ utils.create_learning_rate_scheduler.warmup_steps = 10000
+
+ # Parameters for train/utils.DatasetConfig:
+ # ==============================================================================
+ train/utils.DatasetConfig.batch_size = %BATCH_SIZE
+ train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+ train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+ train/utils.DatasetConfig.pack = True
+ train/utils.DatasetConfig.seed = None
+ train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
+ train/utils.DatasetConfig.split = 'train'
+ train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+ train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
+
+ # Parameters for train_eval/utils.DatasetConfig:
+ # ==============================================================================
+ train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
+ train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+ train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+ train_eval/utils.DatasetConfig.pack = True
+ train_eval/utils.DatasetConfig.seed = 42
+ train_eval/utils.DatasetConfig.shuffle = False
+ train_eval/utils.DatasetConfig.split = 'validation'
+ train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+ train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
+
+ # Parameters for models.EncoderDecoderModel:
+ # ==============================================================================
+ models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
+ models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
+ models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
+ models.EncoderDecoderModel.module = @network.Transformer()
+ models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
+ models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
+ models.EncoderDecoderModel.z_loss = %Z_LOSS
+
+ # Parameters for partitioning.PjitPartitioner:
+ # ==============================================================================
+ partitioning.PjitPartitioner.logical_axis_rules = \
+     @partitioning.standard_logical_axis_rules()
+ partitioning.PjitPartitioner.model_parallel_submesh = None
+ partitioning.PjitPartitioner.num_partitions = 1
+
+ # Parameters for utils.RestoreCheckpointConfig:
+ # ==============================================================================
+ utils.RestoreCheckpointConfig.path = []
+
+ # Parameters for utils.SaveCheckpointConfig:
+ # ==============================================================================
+ utils.SaveCheckpointConfig.dtype = 'float32'
+ utils.SaveCheckpointConfig.keep = 10
+ utils.SaveCheckpointConfig.period = 10000
+ utils.SaveCheckpointConfig.save_dataset = False
+
+ # Parameters for seqio.SentencePieceVocabulary:
+ # ==============================================================================
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = 'spiece.model'
+
+ # Parameters for network.T5Config:
+ # ==============================================================================
+ network.T5Config.dropout_rate = %DROPOUT_RATE
+ network.T5Config.dtype = 'bfloat16'
+ network.T5Config.emb_dim = 512
+ network.T5Config.head_dim = 64
+ network.T5Config.logits_via_embedding = False
+ network.T5Config.mlp_activations = ('gelu', 'linear')
+ network.T5Config.mlp_dim = 2048
+ network.T5Config.num_decoder_layers = 16
+ network.T5Config.num_encoder_layers = 16
+ network.T5Config.num_heads = 8
+ network.T5Config.remat_policy = 'minimal'
+ network.T5Config.scan_layers = True
+ network.T5Config.vocab_size = 32128
+
+ # Parameters for train_script.train:
+ # ==============================================================================
+ train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
+ train_script.train.eval_period = 10000
+ train_script.train.eval_steps = 20
+ train_script.train.infer_eval_dataset_cfg = None
+ train_script.train.model = %MODEL
+ train_script.train.model_dir = %MODEL_DIR
+ train_script.train.partitioner = @partitioning.PjitPartitioner()
+ train_script.train.random_seed = %RANDOM_SEED
+ train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
+ train_script.train.total_steps = %TRAIN_STEPS
+ train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
+ train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
+ train_script.train.trainer_cls = @trainer.Trainer
+ train_script.train.use_gda = False
+ train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
+
+ # Parameters for trainer.Trainer:
+ # ==============================================================================
+ trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
+ trainer.Trainer.num_microbatches = None
+
+ # Parameters for network.Transformer:
+ # ==============================================================================
+ network.Transformer.config = @network.T5Config()
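
For orientation, the macros above fully determine the pretraining budget: 256 packed sequences per batch, 512-token inputs and targets, 500,000 total steps, a checkpoint every 10,000 steps with the 10 newest kept, and evaluation on the 'validation' split every 10,000 steps. A small back-of-the-envelope sketch (not part of the commit; it only re-derives numbers from the macros above, and the token figures are upper bounds because packing fills sequences only up to the 512-token limit):

```python
# Re-derive the training budget from the gin macros above.
BATCH_SIZE = 256
TARGET_LEN = 512             # TASK_FEATURE_LENGTHS['targets']
TRAIN_STEPS = 500_000        # TRAIN_STEPS
SAVE_PERIOD = 10_000         # utils.SaveCheckpointConfig.period

checkpoints_written = TRAIN_STEPS // SAVE_PERIOD              # 50 over the full schedule
target_tokens_per_step = BATCH_SIZE * TARGET_LEN              # 131,072 (upper bound, packed)
tokens_at_50k_steps = 50_000 * target_tokens_per_step         # ~6.6B, the snapshot this commit adds
tokens_at_500k_steps = TRAIN_STEPS * target_tokens_per_step   # ~65.5B at the end of the schedule
print(checkpoints_written, target_tokens_per_step, tokens_at_50k_steps, tokens_at_500k_steps)
```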
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "Finnish-NLP/t5-small-nl16-finnish",
+  "_name_or_path": "/researchdisk/t5-small-nl16-finnish",
   "architectures": [
     "T5ForConditionalGeneration"
   ],
@@ -26,7 +26,7 @@
   "relative_attention_num_buckets": 32,
   "tie_word_embeddings": false,
   "torch_dtype": "float32",
-  "transformers_version": "4.20.1",
+  "transformers_version": "4.21.2",
   "use_cache": true,
   "vocab_size": 32128
 }
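
The only changes to config.json are the local `_name_or_path` and the `transformers_version` bump from 4.20.1 to 4.21.2; the architecture itself is untouched. As a quick sanity check, a sketch (not part of the commit) that loads the updated config from a local checkout of this repo and confirms the unchanged fields:

```python
from transformers import T5Config

# "." is assumed to be a local checkout of this repository containing config.json.
config = T5Config.from_pretrained(".")
print(config.architectures)        # ['T5ForConditionalGeneration']
print(config.vocab_size)           # 32128
print(config.tie_word_embeddings)  # False, consistent with logits_via_embedding = False in config.gin
```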
convert_t5x_checkpoint_to_flax.py CHANGED
@@ -3,7 +3,8 @@
 import argparse
 
 from t5x import checkpoints
-from transformers import T5Config, FlaxT5ForConditionalGeneration
+from transformers import T5Config, FlaxT5ForConditionalGeneration, AutoModelForSeq2SeqLM
+import torch
 
 
 def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
@@ -11,37 +12,36 @@ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_f
     flax_model = FlaxT5ForConditionalGeneration(config=config)
     t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
 
-    split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
+    split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["encoder"]["mlp"]
 
     # Encoder
     for layer_index in range(config.num_layers):
-        layer_name = f"layers_{str(layer_index)}"
 
         # Self-Attention
-        t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
-        t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
-        t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
-        t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
+        t5x_attention_key = t5x_model["target"]["encoder"]["encoder"]["attention"]["key"]["kernel"][:, layer_index, :, :]
+        t5x_attention_out = t5x_model["target"]["encoder"]["encoder"]["attention"]["out"]["kernel"][:, layer_index, :, :]
+        t5x_attention_query = t5x_model["target"]["encoder"]["encoder"]["attention"]["query"]["kernel"][:, layer_index, :, :]
+        t5x_attention_value = t5x_model["target"]["encoder"]["encoder"]["attention"]["value"]["kernel"][:, layer_index, :, :]
 
         ## Layer Normalization
-        t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
+        t5x_attention_layer_norm = t5x_model["target"]["encoder"]["encoder"]["pre_attention_layer_norm"]["scale"][:, layer_index]
 
         if split_mlp_wi:
-            t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
-            t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+            t5x_mlp_wi_0 = t5x_model["target"]["encoder"]["encoder"]["mlp"]["wi_0"]["kernel"][:, layer_index, :]
+            t5x_mlp_wi_1 = t5x_model["target"]["encoder"]["encoder"]["mlp"]["wi_1"]["kernel"][:, layer_index, :]
         else:
-            t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
+            t5x_mlp_wi = t5x_model["target"]["encoder"]["encoder"]["mlp"]["wi"]["kernel"][:, layer_index, :]
 
-        t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
+        t5x_mlp_wo = t5x_model["target"]["encoder"]["encoder"]["mlp"]["wo"]["kernel"][:, layer_index, :]
 
         ## Layer Normalization
-        t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+        t5x_mlp_layer_norm = t5x_model["target"]["encoder"]["encoder"]["pre_mlp_layer_norm"]["scale"][:, layer_index]
 
         # Assigning
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
-        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
+        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key.reshape(*t5x_attention_key.shape[:-2], -1)
+        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out.reshape(-1, t5x_attention_out.shape[-1])
+        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query.reshape(*t5x_attention_query.shape[:-2], -1)
+        flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value.reshape(*t5x_attention_value.shape[:-2], -1)
 
         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
 
@@ -55,59 +55,58 @@ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_f
         flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
 
     # Only for layer 0:
-    t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
-    flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding
+    t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["encoder"]["relpos_bias"]["rel_embedding"].T
+    flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding[:, 0, :]
 
     # Assigning
     t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
     flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
 
     # Decoder
-    for layer_index in range(config.num_layers):
-        layer_name = f"layers_{str(layer_index)}"
+    for layer_index in range(config.num_decoder_layers):
 
         # Self-Attention
-        t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
-        t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
-        t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
-        t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
+        t5x_attention_key = t5x_model["target"]["decoder"]["decoder"]["self_attention"]["key"]["kernel"][:, layer_index, :, :]
+        t5x_attention_out = t5x_model["target"]["decoder"]["decoder"]["self_attention"]["out"]["kernel"][:, layer_index, :, :]
+        t5x_attention_query = t5x_model["target"]["decoder"]["decoder"]["self_attention"]["query"]["kernel"][:, layer_index, :, :]
+        t5x_attention_value = t5x_model["target"]["decoder"]["decoder"]["self_attention"]["value"]["kernel"][:, layer_index, :, :]
 
         ## Layer Normalization
-        t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
+        t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"]["decoder"]["pre_self_attention_layer_norm"]["scale"][:, layer_index]
 
         # Encoder-Decoder-Attention
-        t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
-        t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
-        t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
-        t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
+        t5x_enc_dec_attention_key = t5x_model["target"]["decoder"]["decoder"]["encoder_decoder_attention"]["key"]["kernel"][:, layer_index, :, :]
+        t5x_enc_dec_attention_out = t5x_model["target"]["decoder"]["decoder"]["encoder_decoder_attention"]["out"]["kernel"][:, layer_index, :, :]
+        t5x_enc_dec_attention_query = t5x_model["target"]["decoder"]["decoder"]["encoder_decoder_attention"]["query"]["kernel"][:, layer_index, :, :]
+        t5x_enc_dec_attention_value = t5x_model["target"]["decoder"]["decoder"]["encoder_decoder_attention"]["value"]["kernel"][:, layer_index, :, :]
 
         ## Layer Normalization
-        t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
+        t5x_cross_layer_norm = t5x_model["target"]["decoder"]["decoder"]["pre_cross_attention_layer_norm"]["scale"][:, layer_index]
 
         # MLP
         if split_mlp_wi:
-            t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
-            t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+            t5x_mlp_wi_0 = t5x_model["target"]["decoder"]["decoder"]["mlp"]["wi_0"]["kernel"][:, layer_index, :]
+            t5x_mlp_wi_1 = t5x_model["target"]["decoder"]["decoder"]["mlp"]["wi_1"]["kernel"][:, layer_index, :]
         else:
-            t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
+            t5x_mlp_wi = t5x_model["target"]["decoder"]["decoder"]["mlp"]["wi"]["kernel"][:, layer_index, :]
 
-        t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
+        t5x_mlp_wo = t5x_model["target"]["decoder"]["decoder"]["mlp"]["wo"]["kernel"][:, layer_index, :]
 
         ## Layer Normalization
-        tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+        tx5_mlp_layer_norm = t5x_model["target"]["decoder"]["decoder"]["pre_mlp_layer_norm"]["scale"][:, layer_index]
 
         # Assigning
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key.reshape(*t5x_attention_key.shape[:-2], -1)
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out.reshape(-1, t5x_attention_out.shape[-1])
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query.reshape(*t5x_attention_query.shape[:-2], -1)
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value.reshape(*t5x_attention_value.shape[:-2], -1)
 
         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
 
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
-        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key.reshape(*t5x_enc_dec_attention_key.shape[:-2], -1)
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out.reshape(-1, t5x_enc_dec_attention_out.shape[-1])
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query.reshape(*t5x_enc_dec_attention_query.shape[:-2], -1)
+        flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value.reshape(*t5x_enc_dec_attention_value.shape[:-2], -1)
 
         flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
 
@@ -126,8 +125,8 @@ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_f
     flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
 
     # Only for layer 0:
-    t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
-    flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding
+    t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["decoder"]["relpos_bias"]["rel_embedding"].T
+    flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding[:, 0, :]
 
     # Token Embeddings
     tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
@@ -139,6 +138,10 @@ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_f
     flax_model.save_pretrained(flax_dump_folder_path)
     print("T5X Model was sucessfully converted!")
 
+def convert_flax_to_pytorch(flax_dump_folder_path, pytorch_dump_folder_path):
+    model = AutoModelForSeq2SeqLM.from_pretrained(flax_dump_folder_path, from_flax=True, torch_dtype=torch.float32)
+    model.save_pretrained(pytorch_dump_folder_path)
+    print("Flax model was sucessfully converted to Pytorch!")
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -154,4 +157,6 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
     convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
+    convert_flax_to_pytorch(args.flax_dump_folder_path, args.flax_dump_folder_path)
+
 
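
The diff above adapts the converter to the scalable_t5 checkpoint layout (scanned layers): instead of per-layer `layers_N` entries, each kernel is stacked along a `layers` axis, so the script now indexes `[:, layer_index, ...]` and reshapes the per-head attention kernels back to 2-D before assigning them to the Flax parameters. With the added `convert_flax_to_pytorch` helper, a single run writes both the Flax weights (`flax_model.msgpack`) and a PyTorch export (`pytorch_model.bin`) into the same dump folder, as the `__main__` block shows. A hedged usage sketch (the checkpoint path below is a placeholder, not taken from the commit):

```python
from convert_t5x_checkpoint_to_flax import (
    convert_t5x_checkpoint_to_flax,
    convert_flax_to_pytorch,
)

# Placeholder paths: point T5X_CHECKPOINT at a real T5X checkpoint directory and
# DUMP_FOLDER at the folder holding this repo's config.json.
T5X_CHECKPOINT = "/researchdisk/t5-small-nl16-finnish/checkpoint_50000"
DUMP_FOLDER = "."

convert_t5x_checkpoint_to_flax(T5X_CHECKPOINT, config_name=DUMP_FOLDER, flax_dump_folder_path=DUMP_FOLDER)
convert_flax_to_pytorch(DUMP_FOLDER, DUMP_FOLDER)  # same folder, mirroring the __main__ block above
```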
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af1dfe073774d3fd1413eb3ab376f16e6a8448dd3f979197a1297c2538377b89
+ size 735762207
model-info.txt ADDED
@@ -0,0 +1,148 @@
+ Variable decoder/decoder/encoder_decoder_attention/key/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable decoder/decoder/encoder_decoder_attention/out/kernel size 4194304 shape (heads=8, layers=16, kv=64, embed=512) partition spec ('model', None, None, None)
+ Variable decoder/decoder/encoder_decoder_attention/query/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable decoder/decoder/encoder_decoder_attention/value/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable decoder/decoder/mlp/wi_0/kernel size 16777216 shape (embed=512, layers=16, mlp=2048) partition spec (None, None, 'model')
+ Variable decoder/decoder/mlp/wi_1/kernel size 16777216 shape (embed=512, layers=16, mlp=2048) partition spec (None, None, 'model')
+ Variable decoder/decoder/mlp/wo/kernel size 16777216 shape (mlp=2048, layers=16, embed=512) partition spec ('model', None, None)
+ Variable decoder/decoder/pre_cross_attention_layer_norm/scale size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable decoder/decoder/pre_mlp_layer_norm/scale size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable decoder/decoder/pre_self_attention_layer_norm/scale size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable decoder/decoder/relpos_bias/rel_embedding size 4096 shape (heads=8, layers=16, relpos_buckets=32) partition spec ('model', None, None)
+ Variable decoder/decoder/self_attention/key/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable decoder/decoder/self_attention/out/kernel size 4194304 shape (heads=8, layers=16, kv=64, embed=512) partition spec ('model', None, None, None)
+ Variable decoder/decoder/self_attention/query/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable decoder/decoder/self_attention/value/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable decoder/decoder_norm/scale size 512 shape (embed=512) partition spec (None,)
+ Variable decoder/logits_dense/kernel size 16449536 shape (embed=512, vocab=32128) partition spec (None, 'model')
+ Variable encoder/encoder/attention/key/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable encoder/encoder/attention/out/kernel size 4194304 shape (heads=8, layers=16, kv=64, embed=512) partition spec ('model', None, None, None)
+ Variable encoder/encoder/attention/query/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable encoder/encoder/attention/value/kernel size 4194304 shape (embed=512, layers=16, heads=8, kv=64) partition spec (None, None, 'model', None)
+ Variable encoder/encoder/mlp/wi_0/kernel size 16777216 shape (embed=512, layers=16, mlp=2048) partition spec (None, None, 'model')
+ Variable encoder/encoder/mlp/wi_1/kernel size 16777216 shape (embed=512, layers=16, mlp=2048) partition spec (None, None, 'model')
+ Variable encoder/encoder/mlp/wo/kernel size 16777216 shape (mlp=2048, layers=16, embed=512) partition spec ('model', None, None)
+ Variable encoder/encoder/pre_attention_layer_norm/scale size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable encoder/encoder/pre_mlp_layer_norm/scale size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable encoder/encoder/relpos_bias/rel_embedding size 4096 shape (heads=8, layers=16, relpos_buckets=32) partition spec ('model', None, None)
+ Variable encoder/encoder_norm/scale size 512 shape (embed=512) partition spec (None,)
+ Variable token_embedder/embedding size 16449536 shape (vocab=32128, embed=512) partition spec ('model', None)
+ Total number of parameters: 183944192
+
+ Variable param_states/decoder/decoder/encoder_decoder_attention/key/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/key/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/key/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/key/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/out/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/out/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/out/kernel/v_col size 8192 shape (8, 16, 64) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/out/kernel/v_row size 8192 shape (16, 512) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/query/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/query/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/query/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/query/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/value/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/value/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/value/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/decoder/decoder/encoder_decoder_attention/value/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_0/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_0/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_0/kernel/v_col size 32768 shape (16, 2048) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_0/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_1/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_1/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_1/kernel/v_col size 32768 shape (16, 2048) partition spec None
+ Variable param_states/decoder/decoder/mlp/wi_1/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder/mlp/wo/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/mlp/wo/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/mlp/wo/kernel/v_col size 32768 shape (2048, 16) partition spec None
+ Variable param_states/decoder/decoder/mlp/wo/kernel/v_row size 8192 shape (16, 512) partition spec None
+ Variable param_states/decoder/decoder/pre_cross_attention_layer_norm/scale/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_cross_attention_layer_norm/scale/v size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable param_states/decoder/decoder/pre_cross_attention_layer_norm/scale/v_col size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_cross_attention_layer_norm/scale/v_row size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_mlp_layer_norm/scale/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_mlp_layer_norm/scale/v size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable param_states/decoder/decoder/pre_mlp_layer_norm/scale/v_col size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_mlp_layer_norm/scale/v_row size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_self_attention_layer_norm/scale/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_self_attention_layer_norm/scale/v size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable param_states/decoder/decoder/pre_self_attention_layer_norm/scale/v_col size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/pre_self_attention_layer_norm/scale/v_row size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/relpos_bias/rel_embedding/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/relpos_bias/rel_embedding/v size 4096 shape (heads=8, layers=16, relpos_buckets=32) partition spec ('model', None, None)
+ Variable param_states/decoder/decoder/relpos_bias/rel_embedding/v_col size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/relpos_bias/rel_embedding/v_row size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/key/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/key/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/key/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/decoder/decoder/self_attention/key/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder/self_attention/out/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/out/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/out/kernel/v_col size 8192 shape (8, 16, 64) partition spec None
+ Variable param_states/decoder/decoder/self_attention/out/kernel/v_row size 8192 shape (16, 512) partition spec None
+ Variable param_states/decoder/decoder/self_attention/query/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/query/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/query/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/decoder/decoder/self_attention/query/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder/self_attention/value/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/value/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder/self_attention/value/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/decoder/decoder/self_attention/value/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/decoder/decoder_norm/scale/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder_norm/scale/v size 512 shape (embed=512) partition spec (None,)
+ Variable param_states/decoder/decoder_norm/scale/v_col size 1 shape (1,) partition spec None
+ Variable param_states/decoder/decoder_norm/scale/v_row size 1 shape (1,) partition spec None
+ Variable param_states/decoder/logits_dense/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/decoder/logits_dense/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/decoder/logits_dense/kernel/v_col size 32128 shape (32128,) partition spec None
+ Variable param_states/decoder/logits_dense/kernel/v_row size 512 shape (512,) partition spec None
+ Variable param_states/encoder/encoder/attention/key/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/key/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/key/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/encoder/encoder/attention/key/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/encoder/encoder/attention/out/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/out/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/out/kernel/v_col size 8192 shape (8, 16, 64) partition spec None
+ Variable param_states/encoder/encoder/attention/out/kernel/v_row size 8192 shape (16, 512) partition spec None
+ Variable param_states/encoder/encoder/attention/query/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/query/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/query/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/encoder/encoder/attention/query/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/encoder/encoder/attention/value/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/value/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/attention/value/kernel/v_col size 8192 shape (16, 8, 64) partition spec None
+ Variable param_states/encoder/encoder/attention/value/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_0/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_0/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_0/kernel/v_col size 32768 shape (16, 2048) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_0/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_1/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_1/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_1/kernel/v_col size 32768 shape (16, 2048) partition spec None
+ Variable param_states/encoder/encoder/mlp/wi_1/kernel/v_row size 8192 shape (512, 16) partition spec None
+ Variable param_states/encoder/encoder/mlp/wo/kernel/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/mlp/wo/kernel/v size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/mlp/wo/kernel/v_col size 32768 shape (2048, 16) partition spec None
+ Variable param_states/encoder/encoder/mlp/wo/kernel/v_row size 8192 shape (16, 512) partition spec None
+ Variable param_states/encoder/encoder/pre_attention_layer_norm/scale/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/pre_attention_layer_norm/scale/v size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable param_states/encoder/encoder/pre_attention_layer_norm/scale/v_col size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/pre_attention_layer_norm/scale/v_row size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/pre_mlp_layer_norm/scale/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/pre_mlp_layer_norm/scale/v size 8192 shape (embed=512, layers=16) partition spec (None, None)
+ Variable param_states/encoder/encoder/pre_mlp_layer_norm/scale/v_col size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/pre_mlp_layer_norm/scale/v_row size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/relpos_bias/rel_embedding/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/relpos_bias/rel_embedding/v size 4096 shape (heads=8, layers=16, relpos_buckets=32) partition spec ('model', None, None)
+ Variable param_states/encoder/encoder/relpos_bias/rel_embedding/v_col size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder/relpos_bias/rel_embedding/v_row size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder_norm/scale/m size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder_norm/scale/v size 512 shape (embed=512) partition spec (None,)
+ Variable param_states/encoder/encoder_norm/scale/v_col size 1 shape (1,) partition spec None
+ Variable param_states/encoder/encoder_norm/scale/v_row size 1 shape (1,) partition spec None
+ Variable param_states/token_embedder/embedding/m size 1 shape (1,) partition spec None
+ Variable param_states/token_embedder/embedding/v size 1 shape (1,) partition spec None
+ Variable param_states/token_embedder/embedding/v_col size 32128 shape (32128,) partition spec None
+ Variable param_states/token_embedder/embedding/v_row size 512 shape (512,) partition spec None
+ Variable step size 1 shape () partition spec None
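
The per-variable sizes above add up exactly to the reported total. A small cross-check (not from the repo) that recomputes 183,944,192 from the listed shapes:

```python
# Recompute "Total number of parameters" from the shapes listed above.
emb, layers, heads, kv, mlp, vocab, buckets = 512, 16, 8, 64, 2048, 32128, 32

attn = emb * heads * kv * layers    # one stacked q/k/v/out kernel: 4,194,304
ffn = emb * mlp * layers            # one stacked wi_0/wi_1/wo kernel: 16,777,216
relpos = heads * layers * buckets   # stacked relative position bias: 4,096

encoder = 4 * attn + 3 * ffn + 2 * emb * layers + relpos + emb                # 67,129,856
decoder = 8 * attn + 3 * ffn + 3 * emb * layers + relpos + emb + emb * vocab  # 100,364,800
total = encoder + decoder + vocab * emb                                       # + token_embedder/embedding
print(total)  # 183944192
```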
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05fe4c98850f026cd8154f4358131f3fe9f8538fb692a4621d31a316ac620c80
+ size 735867349
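
Both weight files (`flax_model.msgpack` above and `pytorch_model.bin` here) are tracked with Git LFS, so only the pointer files appear in the diff. A loading sketch (assuming a local clone of this repository with the LFS objects fetched; the path is a placeholder):

```python
from transformers import FlaxT5ForConditionalGeneration, T5ForConditionalGeneration

LOCAL_REPO = "."  # placeholder: path to a local checkout of this repository

flax_model = FlaxT5ForConditionalGeneration.from_pretrained(LOCAL_REPO)  # reads flax_model.msgpack
pt_model = T5ForConditionalGeneration.from_pretrained(LOCAL_REPO)        # reads pytorch_model.bin
```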
train/events.out.tfevents.1661710468.t1v-n-12f94ad0-w-0.60675.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee5357b199786bd136f34c89c093f98ec5417d1cf220340749fc2496418fc60c
+ size 16868
training_eval/pretrain_finnish/events.out.tfevents.1661710468.t1v-n-12f94ad0-w-0.60675.1.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:776cdaf0c0ff210e9e367110778093f7d42c2d9c7836a1a8a4667fb780f2e758
+ size 9244
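
The two event files carry the TensorBoard scalars logged during this run, one for the train split and one for the training_eval/pretrain_finnish split. A sketch for inspecting them offline with the `tensorboard` package (the directory name is a placeholder, and the scalar tag names are whatever T5X wrote, so they are not guaranteed here):

```python
from tensorboard.backend.event_processing import event_accumulator

# Placeholder: directory containing one of the events.out.tfevents.* files above.
ea = event_accumulator.EventAccumulator("train")
ea.Reload()

scalar_tags = ea.Tags()["scalars"]
print(scalar_tags)  # tag names as logged by T5X (e.g. loss/accuracy), not guaranteed here
for event in ea.Scalars(scalar_tags[0]):
    print(event.step, event.value)
```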