aapot committed on
Commit
5d1d3e2
·
1 Parent(s): 062405e

Initial commit

Browse files
.gitattributes CHANGED
@@ -1,6 +1,7 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
@@ -9,14 +10,10 @@
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.npy filter=lfs diff=lfs merge=lfs -text
13
- *.npz filter=lfs diff=lfs merge=lfs -text
14
  *.onnx filter=lfs diff=lfs merge=lfs -text
15
  *.ot filter=lfs diff=lfs merge=lfs -text
16
  *.parquet filter=lfs diff=lfs merge=lfs -text
17
  *.pb filter=lfs diff=lfs merge=lfs -text
18
- *.pickle filter=lfs diff=lfs merge=lfs -text
19
- *.pkl filter=lfs diff=lfs merge=lfs -text
20
  *.pt filter=lfs diff=lfs merge=lfs -text
21
  *.pth filter=lfs diff=lfs merge=lfs -text
22
  *.rar filter=lfs diff=lfs merge=lfs -text
@@ -27,5 +24,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.wasm filter=lfs diff=lfs merge=lfs -text
28
  *.xz filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
- *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
 
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
24
  *.wasm filter=lfs diff=lfs merge=lfs -text
25
  *.xz filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
+ checkpoint*/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "Finnish-NLP/t5-small-nl16-finnish",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 2048,
7
+ "d_kv": 64,
8
+ "d_model": 512,
9
+ "decoder_start_token_id": 0,
10
+ "dense_act_fn": "gelu_new",
11
+ "dropout_rate": 0.1,
12
+ "eos_token_id": 1,
13
+ "feed_forward_proj": "gated-gelu",
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "is_gated_act": true,
17
+ "layer_norm_epsilon": 1e-06,
18
+ "model_type": "t5",
19
+ "n_positions": 512,
20
+ "num_decoder_layers": 16,
21
+ "num_heads": 8,
22
+ "num_layers": 16,
23
+ "output_past": true,
24
+ "pad_token_id": 0,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.20.1",
30
+ "use_cache": true,
31
+ "vocab_size": 32128
32
+ }
convert_t5x_checkpoint_to_flax.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://gist.github.com/stefan-it/30e4998ef159f33696e377a46f699d9f
2
+
3
+ import argparse
4
+
5
+ from t5x import checkpoints
6
+ from transformers import T5Config, FlaxT5ForConditionalGeneration
7
+
8
+
9
def _copy_attention_kernels(t5x_attention, flax_attention):
    """Copy the key/out/query/value kernels of one T5X attention module into the Flax one."""
    for t5x_name, flax_name in (("key", "k"), ("out", "o"), ("query", "q"), ("value", "v")):
        flax_attention[flax_name]["kernel"] = t5x_attention[t5x_name]["kernel"]


def _copy_mlp_kernels(t5x_mlp, flax_mlp, split_mlp_wi):
    """Copy the feed-forward kernels: ``wi_0``/``wi_1`` when split, otherwise ``wi``, plus ``wo``."""
    wi_names = ("wi_0", "wi_1") if split_mlp_wi else ("wi",)
    for kernel_name in wi_names + ("wo",):
        flax_mlp[kernel_name]["kernel"] = t5x_mlp[kernel_name]["kernel"]


def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
    """Load a T5X checkpoint, copy its weights into a Flax T5 model and save that model.

    Args:
        t5x_checkpoint_path: Path handed to ``checkpoints.load_t5x_checkpoint``.
        config_name: Name or path handed to ``T5Config.from_pretrained``.
        flax_dump_folder_path: Directory handed to ``flax_model.save_pretrained``.

    Raises:
        KeyError: If the checkpoint or the Flax parameter tree lacks one of the
            entries that is copied below.
    """
    config = T5Config.from_pretrained(config_name)
    flax_model = FlaxT5ForConditionalGeneration(config=config)
    t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)

    t5x_encoder = t5x_model["target"]["encoder"]
    t5x_decoder = t5x_model["target"]["decoder"]
    flax_encoder = flax_model.params["encoder"]
    flax_decoder = flax_model.params["decoder"]

    # The MLP layout is detected once, from the first encoder layer.
    split_mlp_wi = "wi_0" in t5x_encoder["layers_0"]["mlp"]

    # Encoder
    for layer_index in range(config.num_layers):
        t5x_layer = t5x_encoder[f"layers_{layer_index}"]
        flax_layer = flax_encoder["block"][str(layer_index)]["layer"]

        # Self-attention and its layer normalization
        _copy_attention_kernels(t5x_layer["attention"], flax_layer["0"]["SelfAttention"])
        flax_layer["0"]["layer_norm"]["weight"] = t5x_layer["pre_attention_layer_norm"]["scale"]

        # MLP and its layer normalization
        _copy_mlp_kernels(t5x_layer["mlp"], flax_layer["1"]["DenseReluDense"], split_mlp_wi)
        flax_layer["1"]["layer_norm"]["weight"] = t5x_layer["pre_mlp_layer_norm"]["scale"]

    # The relative position bias is stored once per stack and assigned to block 0 only;
    # the T5X embedding is transposed before assignment.
    flax_encoder["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = (
        t5x_encoder["relpos_bias"]["rel_embedding"].T
    )
    flax_encoder["final_layer_norm"]["weight"] = t5x_encoder["encoder_norm"]["scale"]

    # Decoder: iterate over the decoder depth, which may differ from the encoder depth.
    for layer_index in range(config.num_decoder_layers):
        t5x_layer = t5x_decoder[f"layers_{layer_index}"]
        flax_layer = flax_decoder["block"][str(layer_index)]["layer"]

        # Self-attention and its layer normalization
        _copy_attention_kernels(t5x_layer["self_attention"], flax_layer["0"]["SelfAttention"])
        flax_layer["0"]["layer_norm"]["weight"] = t5x_layer["pre_self_attention_layer_norm"]["scale"]

        # Encoder-decoder attention and its layer normalization
        _copy_attention_kernels(t5x_layer["encoder_decoder_attention"], flax_layer["1"]["EncDecAttention"])
        flax_layer["1"]["layer_norm"]["weight"] = t5x_layer["pre_cross_attention_layer_norm"]["scale"]

        # MLP and its layer normalization
        _copy_mlp_kernels(t5x_layer["mlp"], flax_layer["2"]["DenseReluDense"], split_mlp_wi)
        flax_layer["2"]["layer_norm"]["weight"] = t5x_layer["pre_mlp_layer_norm"]["scale"]

    # Decoder normalization
    flax_decoder["final_layer_norm"]["weight"] = t5x_decoder["decoder_norm"]["scale"]

    # Relative position bias, block 0 only (transposed, as for the encoder).
    flax_decoder["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = (
        t5x_decoder["relpos_bias"]["rel_embedding"].T
    )

    # Token embeddings
    flax_model.params["shared"]["embedding"] = t5x_model["target"]["token_embedder"]["embedding"]

    # LM head: skipped when the checkpoint has no separate output projection
    # (presumably tied-embedding models - TODO confirm) instead of raising KeyError.
    if "logits_dense" in t5x_decoder:
        flax_model.params["lm_head"]["kernel"] = t5x_decoder["logits_dense"]["kernel"]

    flax_model.save_pretrained(flax_dump_folder_path)
    print("T5X Model was sucessfully converted!")
141
+
142
+
143
if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory and are passed
    # straight through to convert_t5x_checkpoint_to_flax.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
    )
    parser.add_argument(
        "--config_name", default=None, type=str, required=True, help="Config name of T5 model."
    )
    parser.add_argument(
        "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
    )
    args = parser.parse_args()
    convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
157
+
small_nl16.gin ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # T5.1.1 Efficient small nl16 model.
2
+
3
+ import seqio
4
+ include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model.
5
+
6
+ # ------------------- Network specification overrides --------------------------
7
+ network.Transformer.config = @network.T5Config()
8
+ network.T5Config:
9
+ emb_dim = 512
10
+ num_heads = 8
11
+ num_encoder_layers = 16
12
+ num_decoder_layers = 16
13
+ head_dim = 64
14
+ mlp_dim = 2048
15
+
16
+ # ------------------- Model specification overrides --------------------------
17
+ VOCABULARY = @seqio.SentencePieceVocabulary()
18
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = "spiece.model"
19
+
20
+ MODEL = @models.EncoderDecoderModel()
21
+ models.EncoderDecoderModel:
22
+ input_vocabulary = %VOCABULARY
23
+ output_vocabulary = %VOCABULARY
small_nl16_pretrain.gin ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Register necessary SeqIO Tasks/Mixtures.
2
+ from __gin__ import dynamic_registration
3
+ from t5x import utils
4
+ from t5x import partitioning
5
+ import tasks
6
+ import __main__ as train_script
7
+
8
+ include 'small_nl16.gin'
9
+ include 't5x/configs/runs/pretrain.gin'
10
+
11
+
12
+ # ------------------- Training specification overrides --------------------------
13
+ train_script.train:
14
+ eval_period = 10000
15
+ use_gda = False
16
+
17
+ utils.SaveCheckpointConfig:
18
+ period = 10000
19
+ keep = 10
20
+
21
+ MIXTURE_OR_TASK_NAME = "pretrain_finnish"
22
+ USE_CACHED_TASKS = False
23
+ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
24
+ TRAIN_STEPS = 500000
25
+ DROPOUT_RATE = 0.0
26
+ BATCH_SIZE = 256
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55a3645122435e9773fac81fa3f94c1e14149e80311636dfa9245fba3e57a826
3
+ size 824186
spiece.vocab ADDED
The diff for this file is too large to render. See raw diff
 
start_train.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set train hyperparams
2
+ unset LD_PRELOAD
3
+
4
+ PROJECT_DIR="/researchdisk/t5-small-nl16-finnish"
5
+ T5X_DIR=${HOME}"/t5x" # directory where the t5x is cloned.
6
+ MODEL_DIR="/researchdisk/t5-small-nl16-finnish"
7
+ export PYTHONPATH=${PROJECT_DIR}
8
+
9
+ python3 ${T5X_DIR}/t5x/train.py \
10
+ --gin_search_paths=${PROJECT_DIR} \
11
+ --gin_file="small_nl16_pretrain.gin" \
12
+ --gin.MODEL_DIR=\"${MODEL_DIR}\"
tasks.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adapted from https://huggingface.co/pere/pk-nb-t5x/blob/main/tasks.py
2
+
3
+ import functools
4
+
5
+ import seqio
6
+ import tensorflow as tf
7
+ import t5.data
8
+ from datasets import load_dataset, load_from_disk
9
+ from t5.data import postprocessors
10
+ from t5.data import preprocessors
11
+ from t5.evaluation import metrics
12
+ from seqio import FunctionDataSource, utils
13
+
14
+ TaskRegistry = seqio.TaskRegistry
15
+
16
+ vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)
17
+
18
+ DEFAULT_OUTPUT_FEATURES = {
19
+ "inputs": seqio.Feature(
20
+ vocabulary=vocabulary, add_eos=True,
21
+ required=False),
22
+ "targets": seqio.Feature(
23
+ vocabulary=vocabulary, add_eos=True)
24
+ }
25
+
26
+
27
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Endlessly yield ``item[column]`` for the items of ``dataset[str(split)]``.

    Args:
        split: Split key; converted with ``str`` before indexing ``dataset``.
        shuffle: When true, ``dataset.shuffle`` is called once before iterating.
        seed: Seed forwarded to ``dataset.shuffle``; ``None`` means unseeded.
        column: Key read from every item.
        dataset: Object providing ``shuffle`` and indexable by split name.

    Yields:
        The ``column`` value of each item. The split is restarted whenever it
        is exhausted, so this generator never terminates on its own.
    """
    if shuffle:
        # Compare against None so that an explicit seed of 0 is honoured
        # rather than being treated as "no seed".
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    split_name = str(split)
    while True:
        for item in dataset[split_name]:
            yield item[column]
36
+
37
+
38
def dataset_fn(split, shuffle_files, seed=None, dataset=None):
    """Wrap ``gen_dataset`` in a ``tf.data.Dataset`` of scalar strings.

    Args:
        split: Forwarded to ``gen_dataset`` as ``split``.
        shuffle_files: Forwarded to ``gen_dataset`` as its ``shuffle`` argument.
        seed: Forwarded to ``gen_dataset`` as ``seed``.
        dataset: Forwarded to ``gen_dataset`` as ``dataset``.

    Returns:
        The dataset built by ``tf.data.Dataset.from_generator``.
    """
    example_generator = functools.partial(gen_dataset, split, shuffle_files, seed, dataset=dataset)
    # The spec is named after the module-level ``dataset_name``, which is read when this runs.
    text_spec = tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    return tf.data.Dataset.from_generator(example_generator, output_signature=text_spec)
43
+
44
+
45
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of ``key_map`` in which ``target_key`` holds ``x``."""
    features = dict(key_map)
    features[target_key] = x
    return features
49
+
50
+
51
+ # Final pretraining task used in Raffel et al., 2019 adaptated to our use
52
+ dataset_name = "/researchdisk/lm_training_dataset_full"
53
+ dataset_params = {"from_disk_path": dataset_name}
54
+
55
+ if "from_disk_path" in dataset_params:
56
+ dataset = load_from_disk(dataset_params.get("from_disk_path"))
57
+ else:
58
+ dataset = load_dataset(**dataset_params)
59
+
60
+ dataset_shapes = {"train": dataset["train"].num_rows, "validation": dataset["validation"].num_rows}
61
+ TaskRegistry.add(
62
+ "pretrain_finnish",
63
+ source=seqio.FunctionDataSource(
64
+ dataset_fn=functools.partial(dataset_fn, dataset=dataset),
65
+ splits=("train", "validation"),
66
+ caching_permitted=False,
67
+ num_input_examples=dataset_shapes,
68
+ ),
69
+ preprocessors=[
70
+ functools.partial(
71
+ target_to_key, key_map={
72
+ "inputs": None,
73
+ "targets": None,
74
+ }, target_key="targets"),
75
+ seqio.preprocessors.tokenize,
76
+ # seqio.CacheDatasetPlaceholder(),
77
+ preprocessors.span_corruption,
78
+ seqio.preprocessors.append_eos_after_trim,
79
+ ],
80
+ output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
81
+ metric_fns=[metrics.accuracy]
82
+ )
train_sentencepiece.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import sentencepiece as spm
2
+
3
+ spm.SentencePieceTrainer.train(input="/researchdisk/lm_training_dataset_full_sentences/train.txt", model_prefix='spiece', vocab_size=32000, character_coverage=1.0,
4
+ pad_id=0, unk_id=2, eos_id=1, bos_id=-1,
5
+ train_extremely_large_corpus=True,
6
+ num_threads=96, input_sentence_size=50000000, shuffle_input_sentence=True)