NLPV committed
Commit 07f96fd · verified · 1 Parent(s): 858cd2d

Delete train_sentencepiece.py

Files changed (1)
  1. train_sentencepiece.py +0 -133
train_sentencepiece.py DELETED
@@ -1,133 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """A script to train a sentencepiece model from tensorflow datasets.
-
- Reserved tokens:
-   pad: 0,
-   eos: 1,
-   unk: 2
-   (bos is not reserved)
- """
-
- import os
- import tempfile
- from typing import List, Tuple
-
- from absl import app
- from absl import flags
- from absl import logging
- import tensorflow as tf, tf_keras
- import tensorflow_datasets as tfds
-
- from sentencepiece import SentencePieceTrainer
-
-
- FLAGS = flags.FLAGS
- flags.DEFINE_string("output_model_path", None,
-                     "Path to save the sentencepiece model.")
- flags.mark_flag_as_required("output_model_path")
-
- flags.DEFINE_string("tfds_dir", None, "Directory of the tfds.")
- flags.DEFINE_string("tfds_name", "wmt14_translate/de-en",
-                     "Name of the dataset we generate the vocabulary from.")
- flags.DEFINE_string("tfds_split", "train", "Split of the dataset.")
- flags.DEFINE_integer("vocab_size", 32000, "Size of vocabulary.")
- flags.DEFINE_integer(
-     "max_char", -1,
-     "Maximum number of characters to use. "
-     "If a non-positive number is provided, all sentences are used.")
- flags.DEFINE_string("model_type", "bpe",
-                     "Model algorithm: unigram, bpe, word or char.")
- flags.DEFINE_float("character_coverage", 0.9995,
-                    "Character coverage to determine the minimum symbols.")
- flags.DEFINE_list(
-     "data_keys", ["en", "de"],
-     "Comma-separated list of keys to use for training the vocabulary.")
-
-
- def dump_chars_to_textfile(dataset: tf.data.Dataset,
-                            data_keys: Tuple[str],
-                            max_char: int = -1):
-   """Write part of a TFDS sentence dataset to lines in a text file.
-
-   Args:
-     dataset: tf.data.Dataset containing string data.
-     data_keys: which keys in the dataset examples to dump.
-     max_char: maximum number of characters to dump to the text file.
-
-   Returns:
-     name of the temp file the dataset text was dumped to.
-   """
-   ds_iter = dataset.as_numpy_iterator()
-   with tempfile.NamedTemporaryFile(delete=False) as outfp:
-     char_count = 0
-     while True:
-       example = next(ds_iter, None)
-       if example is None or (
-           max_char > 0 and char_count > max_char):
-         break
-       for k in data_keys:
-         line = example[k] + b"\n"
-         char_count += len(line)
-         outfp.write(line)
-   return outfp.name
-
-
- def train_sentencepiece(
-     file_path: str,
-     model_path: str,
-     vocab_size: int,
-     character_coverage: float,
-     model_type: str):
-   """Train a SentencePiece tokenizer from a subset of a tf dataset.
-
-   The trained model files are written using `model_path` as the file prefix.
-
-   Args:
-     file_path: path of the text data to train sentencepiece on.
-     model_path: path prefix to save the trained vocab model to.
-     vocab_size: size of vocab tokens to train.
-     character_coverage: amount of characters covered by the model; good
-       defaults are 0.9995 for languages with rich character sets such as
-       Japanese or Chinese, and 1.0 for other languages with small character
-       sets.
-     model_type: type of sentencepiece vocab to train.
-   """
-   argstr = " ".join([
-       f"--input={file_path}", f"--vocab_size={vocab_size}",
-       f"--character_coverage={character_coverage}",
-       f"--model_prefix={model_path}", f"--model_type={model_type}",
-       "--bos_id=-1", "--pad_id=0", "--eos_id=1", "--unk_id=2"
-   ])
-   SentencePieceTrainer.Train(argstr)
-
-
- def main(argv: List[str]):
-   del argv
-   builder = tfds.builder(FLAGS.tfds_name, data_dir=FLAGS.tfds_dir)
-   ds = builder.as_dataset(split=FLAGS.tfds_split)
-   tmp_filename = dump_chars_to_textfile(ds, FLAGS.data_keys, FLAGS.max_char)
-   logging.info("Sentencepiece model will be placed here: %s",
-                FLAGS.output_model_path)
-   train_sentencepiece(tmp_filename,
-                       FLAGS.output_model_path,
-                       FLAGS.vocab_size,
-                       FLAGS.character_coverage,
-                       FLAGS.model_type)
-   os.remove(tmp_filename)
-
-
- if __name__ == "__main__":
-   app.run(main)
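
For reference, the deleted script was driven entirely by its absl flags, so it would have been invoked roughly as follows; the output path below is a placeholder, and the remaining values simply restate the script's own defaults:

  python train_sentencepiece.py \
      --output_model_path=/tmp/wmt14_bpe \
      --tfds_name=wmt14_translate/de-en \
      --tfds_split=train \
      --vocab_size=32000 \
      --model_type=bpe

This dumps the selected tfds split to a temporary text file and trains a BPE vocabulary with the reserved ids pad=0, eos=1, unk=2 (bos unreserved), as set in the script's SentencePieceTrainer arguments.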