Siddharth63 commited on
Commit
d67c39e
·
1 Parent(s): aae3052

Upload 10 files

Browse files
config.gin ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import __main__ as train_script
3
+ import seqio
4
+ from t5x import adafactor
5
+ from t5x.examples.t5 import network
6
+ from t5x import gin_utils
7
+ from t5x import models
8
+ from t5x import partitioning
9
+ from t5x import trainer
10
+ from t5x import utils
11
+ import ul2_tasks
12
+
13
+ # Macros:
14
+ # ==============================================================================
15
+ BATCH_SIZE = 256
16
+ DROPOUT_RATE = 0.0
17
+ LABEL_SMOOTHING = 0.0
18
+ LOSS_NORMALIZING_FACTOR = None
19
+ MIXTURE_OR_TASK_MODULE = None
20
+ MIXTURE_OR_TASK_NAME = 'pretrain_biological_ul2'
21
+ MODEL = @models.EncoderDecoderModel()
22
+ MODEL_DIR = '/models/bioul2-mini-nl8'
23
+ OPTIMIZER = @adafactor.Adafactor()
24
+ RANDOM_SEED = None
25
+ SHUFFLE_TRAIN_EXAMPLES = True
26
+ TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
27
+ TRAIN_STEPS = 500000
28
+ USE_CACHED_TASKS = False
29
+ USE_HARDWARE_RNG = False
30
+ VOCABULARY = @seqio.SentencePieceVocabulary()
31
+ Z_LOSS = 0.0001
32
+
33
+ # Parameters for adafactor.Adafactor:
34
+ # ==============================================================================
35
+ adafactor.Adafactor.decay_rate = 0.8
36
+ adafactor.Adafactor.logical_factor_rules = \
37
+ @adafactor.standard_logical_factor_rules()
38
+ adafactor.Adafactor.step_offset = 0
39
+
40
+ # Parameters for utils.CheckpointConfig:
41
+ # ==============================================================================
42
+ utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
43
+ utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
44
+
45
+ # Parameters for utils.create_learning_rate_scheduler:
46
+ # ==============================================================================
47
+ utils.create_learning_rate_scheduler.base_learning_rate = 1.0
48
+ utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
49
+ utils.create_learning_rate_scheduler.warmup_steps = 10000
50
+
51
+ # Parameters for train/utils.DatasetConfig:
52
+ # ==============================================================================
53
+ train/utils.DatasetConfig.batch_size = %BATCH_SIZE
54
+ train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
55
+ train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
56
+ train/utils.DatasetConfig.pack = True
57
+ train/utils.DatasetConfig.seed = None
58
+ train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
59
+ train/utils.DatasetConfig.split = 'train'
60
+ train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
61
+ train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
62
+
63
+ # Parameters for train_eval/utils.DatasetConfig:
64
+ # ==============================================================================
65
+ train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
66
+ train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
67
+ train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
68
+ train_eval/utils.DatasetConfig.pack = True
69
+ train_eval/utils.DatasetConfig.seed = 42
70
+ train_eval/utils.DatasetConfig.shuffle = False
71
+ train_eval/utils.DatasetConfig.split = 'validation'
72
+ train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
73
+ train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
74
+
75
+ # Parameters for models.EncoderDecoderModel:
76
+ # ==============================================================================
77
+ models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
78
+ models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
79
+ models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
80
+ models.EncoderDecoderModel.module = @network.Transformer()
81
+ models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
82
+ models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
83
+ models.EncoderDecoderModel.z_loss = %Z_LOSS
84
+
85
+ # Parameters for partitioning.PjitPartitioner:
86
+ # ==============================================================================
87
+ partitioning.PjitPartitioner.logical_axis_rules = \
88
+ @partitioning.standard_logical_axis_rules()
89
+ partitioning.PjitPartitioner.model_parallel_submesh = None
90
+ partitioning.PjitPartitioner.num_partitions = 1
91
+
92
+ # Parameters for utils.RestoreCheckpointConfig:
93
+ # ==============================================================================
94
+ utils.RestoreCheckpointConfig.path = []
95
+
96
+ # Parameters for utils.SaveCheckpointConfig:
97
+ # ==============================================================================
98
+ utils.SaveCheckpointConfig.dtype = 'float32'
99
+ utils.SaveCheckpointConfig.keep = 5
100
+ utils.SaveCheckpointConfig.period = 10000
101
+ utils.SaveCheckpointConfig.save_dataset = False
102
+
103
+ # Parameters for seqio.SentencePieceVocabulary:
104
+ # ==============================================================================
105
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = 'spiece.model'
106
+
107
+ # Parameters for network.T5Config:
108
+ # ==============================================================================
109
+ network.T5Config.dropout_rate = %DROPOUT_RATE
110
+ network.T5Config.dtype = 'bfloat16'
111
+ network.T5Config.emb_dim = 384
112
+ network.T5Config.head_dim = 64
113
+ network.T5Config.logits_via_embedding = False
114
+ network.T5Config.mlp_activations = ('gelu', 'linear')
115
+ network.T5Config.mlp_dim = 1536
116
+ network.T5Config.num_decoder_layers = 8
117
+ network.T5Config.num_encoder_layers = 8
118
+ network.T5Config.num_heads = 8
119
+ network.T5Config.vocab_size = 32128
120
+
121
+ # Parameters for train_script.train:
122
+ # ==============================================================================
123
+ train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
124
+ train_script.train.eval_period = 10000
125
+ train_script.train.eval_steps = 20
126
+ train_script.train.infer_eval_dataset_cfg = None
127
+ train_script.train.model = %MODEL
128
+ train_script.train.model_dir = %MODEL_DIR
129
+ train_script.train.partitioner = @partitioning.PjitPartitioner()
130
+ train_script.train.random_seed = %RANDOM_SEED
131
+ train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
132
+ train_script.train.total_steps = %TRAIN_STEPS
133
+ train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
134
+ train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
135
+ train_script.train.trainer_cls = @trainer.Trainer
136
+ train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
137
+
138
+ # Parameters for trainer.Trainer:
139
+ # ==============================================================================
140
+ trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
141
+ trainer.Trainer.num_microbatches = None
142
+
143
+ # Parameters for network.Transformer:
144
+ # ==============================================================================
145
+ network.Transformer.config = @network.T5Config()
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 1536,
7
+ "d_kv": 64,
8
+ "d_model": 384,
9
+ "decoder_start_token_id": 0,
10
+ "dense_act_fn": "gelu_new",
11
+ "dropout_rate": 0.1,
12
+ "eos_token_id": 1,
13
+ "feed_forward_proj": "gated-gelu",
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "is_gated_act": true,
17
+ "layer_norm_epsilon": 1e-06,
18
+ "model_type": "t5",
19
+ "n_positions": 512,
20
+ "num_decoder_layers": 8,
21
+ "num_heads": 8,
22
+ "num_layers": 8,
23
+ "output_past": true,
24
+ "pad_token_id": 0,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.22.1",
30
+ "use_cache": true,
31
+ "vocab_size": 32128
32
+ }
convert_t5x_checkpoint_to_flax.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://gist.github.com/stefan-it/30e4998ef159f33696e377a46f699d9f
2
+
3
+ import argparse
4
+
5
+ from t5x import checkpoints
6
+ from transformers import T5Config, FlaxT5ForConditionalGeneration, AutoModelForSeq2SeqLM
7
+ import torch
8
+
9
+
10
+ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
11
+ config = T5Config.from_pretrained(config_name)
12
+ flax_model = FlaxT5ForConditionalGeneration(config=config)
13
+ t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
14
+
15
+ split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
16
+
17
+ # Encoder
18
+ for layer_index in range(config.num_layers):
19
+ layer_name = f"layers_{str(layer_index)}"
20
+
21
+ # Self-Attention
22
+ t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
23
+ t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
24
+ t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
25
+ t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
26
+
27
+ ## Layer Normalization
28
+ t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
29
+
30
+ if split_mlp_wi:
31
+ t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
32
+ t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
33
+ else:
34
+ t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
35
+
36
+ t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
37
+
38
+ ## Layer Normalization
39
+ t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
40
+
41
+ # Assigning
42
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
43
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
44
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
45
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
46
+
47
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
48
+
49
+ if split_mlp_wi:
50
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
51
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
52
+ else:
53
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
54
+
55
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
56
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
57
+
58
+ # Only for layer 0:
59
+ t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
60
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding
61
+
62
+ # Assigning
63
+ t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
64
+ flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
65
+
66
+ # Decoder
67
+ for layer_index in range(config.num_decoder_layers):
68
+ layer_name = f"layers_{str(layer_index)}"
69
+
70
+ # Self-Attention
71
+ t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
72
+ t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
73
+ t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
74
+ t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
75
+
76
+ ## Layer Normalization
77
+ t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
78
+
79
+ # Encoder-Decoder-Attention
80
+ t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
81
+ t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
82
+ t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
83
+ t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
84
+
85
+ ## Layer Normalization
86
+ t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
87
+
88
+ # MLP
89
+ if split_mlp_wi:
90
+ t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
91
+ t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
92
+ else:
93
+ t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
94
+
95
+ t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
96
+
97
+ ## Layer Normalization
98
+ tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
99
+
100
+ # Assigning
101
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
102
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
103
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
104
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
105
+
106
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
107
+
108
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
109
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
110
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
111
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
112
+
113
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
114
+
115
+ if split_mlp_wi:
116
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
117
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
118
+ else:
119
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
120
+
121
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
122
+
123
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
124
+
125
+ # Decoder Normalization
126
+ tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
127
+ flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
128
+
129
+ # Only for layer 0:
130
+ t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
131
+ flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding
132
+
133
+ # Token Embeddings
134
+ tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
135
+ flax_model.params["shared"]["embedding"] = tx5_token_embeddings
136
+
137
+ # LM Head
138
+ flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
139
+
140
+ flax_model.save_pretrained(flax_dump_folder_path)
141
+ print("T5X Model was sucessfully converted!")
142
+
143
+ def convert_flax_to_pytorch(flax_dump_folder_path, pytorch_dump_folder_path):
144
+ model = AutoModelForSeq2SeqLM.from_pretrained(flax_dump_folder_path, from_flax=True, torch_dtype=torch.float32)
145
+ model.save_pretrained(pytorch_dump_folder_path)
146
+ print("Flax model was sucessfully converted to Pytorch!")
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser()
150
+ # Required parameters
151
+ parser.add_argument(
152
+ "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
153
+ )
154
+ parser.add_argument(
155
+ "--config_name", default=None, type=str, required=True, help="Config name of T5 model."
156
+ )
157
+ parser.add_argument(
158
+ "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
159
+ )
160
+ args = parser.parse_args()
161
+ convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
162
+ convert_flax_to_pytorch(args.flax_dump_folder_path, args.flax_dump_folder_path)
export_checkpoint.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration
3
+
4
+ def main(args):
5
+ pt_model = T5ForConditionalGeneration.from_pretrained(args.model_dir, from_flax=True)
6
+ pt_model.save_pretrained(args.model_dir)
7
+ tf_model = TFT5ForConditionalGeneration.from_pretrained(args.model_dir, from_pt=True)
8
+ tf_model.save_pretrained(args.model_dir)
9
+
10
+
11
+ if __name__ == "__main__":
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--model_dir', type=str, default='.')
mini_nl8.gin ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # T5.1.1 Efficient mini nl8 model.
2
+
3
+ import seqio
4
+ include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model.
5
+
6
+ # ------------------- Network specification overrides --------------------------
7
+ network.Transformer.config = @network.T5Config()
8
+ network.T5Config:
9
+ emb_dim = 384
10
+ num_heads = 8
11
+ num_encoder_layers = 8
12
+ num_decoder_layers = 8
13
+ head_dim = 64
14
+ mlp_dim = 1536
15
+
16
+ # ------------------- Model specification overrides --------------------------
17
+ VOCABULARY = @seqio.SentencePieceVocabulary()
18
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = "spiece.model"
19
+
20
+ MODEL = @models.EncoderDecoderModel()
21
+ models.EncoderDecoderModel:
22
+ input_vocabulary = %VOCABULARY
23
+ output_vocabulary = %VOCABULARY
mini_nl8_pretrain.gin ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Register necessary SeqIO Tasks/Mixtures.
2
+ from __gin__ import dynamic_registration
3
+ from t5x import utils
4
+ import ul2_tasks
5
+ import __main__ as train_script
6
+
7
+ include 'mini_nl8.gin'
8
+ include 't5x/configs/runs/pretrain.gin'
9
+
10
+
11
+ # ------------------- Training specification overrides --------------------------
12
+ train_script.train:
13
+ eval_period = 2000
14
+ stats_period = 100
15
+
16
+ utils.SaveCheckpointConfig:
17
+ period = 50000
18
+ keep = 4
19
+ use_gda = False
20
+
21
+ MIXTURE_OR_TASK_NAME = "pretrain_biological_ul2"
22
+ USE_CACHED_TASKS = False
23
+ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
24
+ TRAIN_STEPS = 500000
25
+ DROPOUT_RATE = 0.0
26
+ BATCH_SIZE = 256
model-info.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets>=1.16.1
2
+ transformers>=4.13.0
3
+ flax>=0.3.5
4
+ optax>=0.1.0
5
+ tqdm>=4.61.1
6
+ numpy>=1.19.5
7
+ tokenizers>=0.10.3
8
+ sentencepiece>=0.1.96
9
+ protobuf>=3.17.3,<=3.20.99
10
+ tensorboard>=2.7.0
11
+ torch>=1.9.0
12
+ tensorflow>=2.7.0
13
+ jax[tpu]>=0.2.28
ul2_objective.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import tensorflow as tf
3
+ import seqio
4
+ import t5.data
5
+ from typing import Optional, Sequence
6
+
7
+ # found this function and modified from https://github.com/GoogleCloudPlatform/t5x-on-vertex-ai/blob/main/tasks/custom_tasks.py#L78
8
+ # UL2 paper appendix code missed this function
9
+ def prepend_prompt(
10
+ dataset: tf.data.Dataset,
11
+ output_features: seqio.preprocessors.OutputFeaturesType,
12
+ sequence_length: Optional[seqio.preprocessors.SequenceLengthType] = None,
13
+ prompt_mode: str = "",
14
+ key: str = "inputs",
15
+ mode: str = "",
16
+ ) -> tf.data.Dataset:
17
+ """Prepends a prompt at the beginning of an input sequence."""
18
+ del sequence_length
19
+ if prompt_mode and mode:
20
+ # output_features may not have inputs key
21
+ out_keys = list(output_features.keys())
22
+ prompt_tokens = output_features[out_keys[0]].vocabulary.encode_tf(prompt_mode)
23
+
24
+ def add_to_inputs(x):
25
+ x[key] = tf.concat([prompt_tokens, x[key]], axis=0)
26
+ return x
27
+
28
+ dataset = dataset.map(add_to_inputs)
29
+ return dataset
30
+
31
+
32
+ # modified from t5.data.preprocessors because output_features may not have inputs key
33
+ def split_tokens_to_inputs_length(dataset, sequence_length, output_features, **kwargs):
34
+ max_tokens = sequence_length["inputs"]
35
+ # output_features may not have inputs key
36
+ out_keys = list(output_features.keys())
37
+ if output_features[out_keys[0]].add_eos:
38
+ # Leave room to insert an EOS token.
39
+ max_tokens -= 1
40
+
41
+ return t5.data.preprocessors.split_tokens(
42
+ dataset, max_tokens_per_segment=max_tokens, **kwargs
43
+ )
44
+
45
+
46
+ # modified from t5.data.preprocessors because output_features may not have inputs key
47
+ def prefix_lm(dataset, sequence_length, output_features):
48
+ """Prefix language modeling objective used in Raffel et al. 2019."""
49
+ ds = dataset
50
+ ds = t5.data.preprocessors.select_random_chunk(
51
+ ds, output_features=output_features, feature_key="targets", max_length=65536
52
+ )
53
+ ds = split_tokens_to_inputs_length(
54
+ ds, output_features=output_features, sequence_length=sequence_length
55
+ )
56
+ ds = t5.data.preprocessors.denoise(
57
+ ds,
58
+ output_features,
59
+ inputs_fn=t5.data.preprocessors.drop_nonnoise_tokens,
60
+ targets_fn=t5.data.preprocessors.drop_noise_tokens,
61
+ noise_density=0.5,
62
+ noise_mask_fn=t5.data.preprocessors.random_prefix_noise_mask,
63
+ )
64
+ return ds
65
+
66
+
67
+ # copied from UL2 paper https://arxiv.org/pdf/2205.05131.pdf appendix chapter 9.2
68
+ # note: modified to use the prefix_lm() from above instead of the default t5.data.preprocessors.prefix_lm() because output_features may not have inputs key
69
+ def ul2_objective(
70
+ dataset: tf.data.Dataset,
71
+ sequence_length: seqio.preprocessors.SequenceLengthType,
72
+ output_features: seqio.preprocessors.OutputFeaturesType,
73
+ use_prefix_lm_task: bool = False,
74
+ rates: Optional[Sequence[float]] = None,
75
+ mean_noise_span_lengths: Sequence[float] = (3.0,),
76
+ noise_densities: Sequence[float] = (0.15,),
77
+ shard_ds: bool = True,
78
+ optional_task_prefixes: Optional[Sequence[str]] = None,
79
+ input_feature_key: str = "inputs",
80
+ merge_examples_to_reduce_padding: bool = True,
81
+ reserved_for_packing: bool = None,
82
+ seed: int = 7,
83
+ ) -> tf.data.Dataset:
84
+ """UL2-like pre-training objectives.
85
+ This preprocessor amounts to calling the 'span_corruption' function several
86
+ times with different values of 'noise_density' and 'mean_noise_span_length'.
87
+ We either shard or copy the dataset, then apply each function to each shard.
88
+ Add S-denoising (prefixLM) using use_prefix_lm_task.
89
+ Args:
90
+ dataset: A tf.data.Dataset with dictionaries containing the key 'input_feature_key'.
91
+ sequence_length: dict mapping of feature key to int length for that feature.
92
+ output_features: mapping of keys to features.
93
+ use_prefix_lm_task: <bool> If True, include PrefixLM in the task mix.
94
+ rates: <Optional<List<float>> List of rates per task. If None, tasks are sampled uniformly.
95
+ mean_noise_span_lengths: List of mean number of tokens per masked span per example.
96
+ noise_densities: List of what fraction of the tokens to mask.
97
+ shard_ds: <bool> If True, shard dataset per objective.
98
+ optional_task_prefixes: <Optional<list<str>> Strings to prepend for each corruption scheme. NOTE: If including prefixLM task, it must be the last prefix.
99
+ input_feature_key: which feature to use from the dataset as the input text tokens.
100
+ merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
101
+ reserved_for_packing: if specified, reduces the desired inputs length by the specified amount to enable multiple examples to be packed together downstream.
102
+ seed: tf.int64 for controlling the random choice of spans.
103
+ Returns:
104
+ a dataset
105
+ """
106
+
107
+ if optional_task_prefixes: # Ensure each task has a prefix.
108
+ num_tasks = len(noise_densities) + int(use_prefix_lm_task)
109
+ valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
110
+ if not valid_number_of_prefixes:
111
+ raise ValueError("Number of task prefixes must match number of tasks.")
112
+ inputs_length = sequence_length[input_feature_key]
113
+ input_lengths, targets_lengths = [], []
114
+ sequence_lengths = {x: y for x, y in sequence_length.items()}
115
+ if reserved_for_packing:
116
+ inputs_length -= reserved_for_packing
117
+ for x, y in sequence_length.items():
118
+ sequence_lengths[x] = y - reserved_for_packing
119
+ hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
120
+ for mean_noise_span_length, noise_density in hyperparams:
121
+ input_length, targets_length = t5.data.preprocessors.random_spans_helper(
122
+ extra_tokens_per_span_inputs=1,
123
+ extra_tokens_per_span_targets=1,
124
+ inputs_length=inputs_length,
125
+ mean_noise_span_length=mean_noise_span_length,
126
+ noise_density=noise_density,
127
+ )
128
+ input_lengths.append(input_length)
129
+ targets_lengths.append(targets_length)
130
+
131
+ if sequence_length["targets"] < targets_length:
132
+ upper_bound = max(targets_lengths)
133
+ raise ValueError(
134
+ f"Expected max targets length for span corruption ({upper_bound}) is "
135
+ f"greater than configured targets length "
136
+ f"({sequence_length['targets']})"
137
+ )
138
+ ds = dataset
139
+ ds = t5.data.preprocessors.select_random_chunk(
140
+ ds, output_features=output_features, feature_key="targets", max_length=65536
141
+ )
142
+ if merge_examples_to_reduce_padding:
143
+ ds = t5.data.preprocessors.reduce_concat_tokens(
144
+ ds, feature_key="targets", batch_size=128
145
+ )
146
+ num_shards = len(input_lengths) + int(use_prefix_lm_task)
147
+ if shard_ds:
148
+ ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
149
+ else:
150
+ ds_shards = [ds for _ in range(num_shards)]
151
+ processed_ds = []
152
+ hyperparams = zip(input_lengths, hyperparams, range(num_shards))
153
+ for input_length, (noise_span_length, noise_density), i in hyperparams:
154
+ ds = ds_shards[i]
155
+ ds = t5.data.preprocessors.split_tokens(
156
+ ds,
157
+ feature_key="targets",
158
+ min_tokens_per_segment=None,
159
+ max_tokens_per_segment=input_length,
160
+ )
161
+ ds = t5.data.preprocessors.denoise(
162
+ ds,
163
+ output_features,
164
+ inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
165
+ targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
166
+ noise_density=noise_density,
167
+ noise_mask_fn=functools.partial(
168
+ t5.data.preprocessors.random_spans_noise_mask,
169
+ mean_noise_span_length=noise_span_length,
170
+ ),
171
+ input_feature_key=input_feature_key,
172
+ )
173
+ if optional_task_prefixes:
174
+ ds = prepend_prompt(
175
+ ds,
176
+ output_features,
177
+ prompt_mode=optional_task_prefixes[i],
178
+ mode=optional_task_prefixes[i],
179
+ key=input_feature_key,
180
+ )
181
+ processed_ds.append(ds)
182
+ if use_prefix_lm_task:
183
+ ds = ds_shards[-1]
184
+ ds = prefix_lm(ds, sequence_lengths, output_features)
185
+ if optional_task_prefixes:
186
+ ds = prepend_prompt(
187
+ ds,
188
+ output_features,
189
+ prompt_mode=optional_task_prefixes[-1],
190
+ mode=optional_task_prefixes[-1],
191
+ key=input_feature_key,
192
+ )
193
+ processed_ds.append(ds)
194
+ ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
195
+ return ds
ul2_tasks.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import Dict
3
+
4
+ import seqio
5
+ import tensorflow as tf
6
+ from datasets import load_dataset, load_from_disk
7
+ from t5.evaluation import metrics
8
+ from seqio import utils, FunctionDataSource
9
+ import t5.data
10
+ from datasets import load_dataset, load_from_disk
11
+ from t5.data import postprocessors
12
+ from t5.data import preprocessors
13
+
14
+
15
+ from .ul2_objective import ul2_objective
16
+
17
+ # values from UL2 paper https://arxiv.org/pdf/2205.05131.pdf chapter 3.1.2 table 1
18
+ R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
19
+ X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
20
+ R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
21
+ X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]
22
+
23
+ R_DENOISER_TOKEN_PREFIX = "[NLU]"
24
+ X_DENOISER_TOKEN_PREFIX = "[NLG]"
25
+ S_DENOISER_TOKEN_PREFIX = "[S2S]"
26
+
27
+ TaskRegistry = seqio.TaskRegistry
28
+
29
+ vocabulary = seqio.SentencePieceVocabulary("spiece.model")
30
+
31
+ DEFAULT_OUTPUT_FEATURES = {
32
+ "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
33
+ "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True),
34
+ }
35
+
36
+ def gen_dataset(split, shuffle=False, seed=None, column="text", path=None, name=None):
37
+ dataset = load_dataset(path, name, streaming=True, use_auth_token=True)
38
+ if shuffle:
39
+ if seed:
40
+ dataset = dataset.shuffle(seed=seed)
41
+ else:
42
+ dataset = dataset.shuffle()
43
+ while True:
44
+ for item in dataset[str(split)]:
45
+ yield item[column]
46
+
47
+
48
+ def dataset_fn(split, shuffle_files, seed=None, path=None, name=None):
49
+ return tf.data.Dataset.from_generator(
50
+ functools.partial(
51
+ gen_dataset, split, shuffle_files, seed, path=path, name=name
52
+ ),
53
+ output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=path),
54
+ )
55
+
56
+
57
+ @utils.map_over_dataset
58
+ def target_to_key(x, key_map, target_key):
59
+ """Assign the value from the dataset to target_key in key_map"""
60
+ return {**key_map, target_key: x}
61
+
62
+ ## First way to add to task registry
63
+ dataset_name = 'Siddharth63/biological_dataset'
64
+ dataset = load_dataset(dataset_name)
65
+
66
+ dataset_shapes = {"train": dataset["train"].num_rows,
67
+ "validation": dataset["validation"].num_rows}
68
+
69
+ TaskRegistry.add(
70
+ "pretrain_biological_ul2",
71
+ source=seqio.FunctionDataSource(
72
+ dataset_fn=functools.partial(dataset_fn, dataset=dataset),
73
+ splits=("train", "validation"),
74
+ caching_permitted=False,
75
+ num_input_examples=dataset_shapes,
76
+ ),
77
+ preprocessors=[
78
+ functools.partial(
79
+ target_to_key, key_map={
80
+ "inputs": None,
81
+ "targets": None,
82
+ }, target_key="targets"),
83
+ seqio.preprocessors.tokenize,
84
+ functools.partial(
85
+ ul2_objective,
86
+ shard_ds=False,
87
+ use_prefix_lm_task=True, # use S-denoising
88
+ rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)]*len(R_DENOISER_SPAN_LENGTHS) + [
89
+ 0.4 / len(X_DENOISER_SPAN_LENGTHS)]*len(X_DENOISER_SPAN_LENGTHS) + [0.2], # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
90
+ mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
91
+ noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
92
+ optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]*len(R_DENOISER_SPAN_LENGTHS) + [
93
+ X_DENOISER_TOKEN_PREFIX]*len(X_DENOISER_SPAN_LENGTHS) + [S_DENOISER_TOKEN_PREFIX],
94
+ reserved_for_packing=1, # make room for task prefix token
95
+ ),
96
+ seqio.preprocessors.append_eos_after_trim,
97
+ ],
98
+ output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
99
+ metric_fns=[metrics.accuracy]
100
+ )
101
+
102
+
103
+ ## Second way to add to task registry
104
+ # TaskRegistry.add(
105
+ # "pretrain_biological_ul2",
106
+ # source=seqio.FunctionDataSource(
107
+ # dataset_fn=functools.partial(
108
+ # dataset_fn, path="Siddharth63/biological_dataset", name="full"
109
+ # ),
110
+ # splits=("train", "validation"),
111
+ # caching_permitted=False,
112
+ # ),
113
+ # preprocessors=[
114
+ # functools.partial(
115
+ # target_to_key,
116
+ # key_map={
117
+ # "inputs": "text",
118
+ # "targets": "text",
119
+ # },
120
+ # target_key="targets",
121
+ # ),
122
+ # seqio.preprocessors.tokenize,
123
+ # functools.partial(
124
+ # ul2_objective,
125
+ # shard_ds=False,
126
+ # use_prefix_lm_task=True, # use S-denoising
127
+ # rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS)
128
+ # + [0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS)
129
+ # + [
130
+ # 0.2
131
+ # ], # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
132
+ # mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
133
+ # noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
134
+ # optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]
135
+ # * len(R_DENOISER_SPAN_LENGTHS)
136
+ # + [X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS)
137
+ # + [S_DENOISER_TOKEN_PREFIX],
138
+ # reserved_for_packing=1, # make room for task prefix token
139
+ # ),
140
+ # seqio.preprocessors.append_eos_after_trim,
141
+ # ],
142
+ # output_features={
143
+ # "targets": DEFAULT_OUTPUT_FEATURES["targets"],
144
+ # "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
145
+ # },
146
+ # metric_fns=[metrics.accuracy],
147
+ # )