# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to train sentencepiece model from tensorflow datasets. | |
Reserved tokens: | |
pad: 0, | |
eos: 1, | |
unk: 2 | |
(bos is not reserved) | |
""" | |
import os
import tempfile
from typing import List, Tuple

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds

from sentencepiece import SentencePieceTrainer

FLAGS = flags.FLAGS
flags.DEFINE_string("output_model_path", None, | |
"Path to save the sentencepiece model.") | |
flags.mark_flag_as_required("output_model_path") | |
flags.DEFINE_string("tfds_dir", None, "Directory of the tfds.") | |
flags.DEFINE_string("tfds_name", "wmt14_translate/de-en", | |
"Name of the dataset we generate vacabulay from.") | |
flags.DEFINE_string("tfds_split", "train", "Split of the dataset.") | |
flags.DEFINE_integer("vocab_size", 32000, "Size of vocabulary.") | |
flags.DEFINE_integer( | |
"max_char", -1, | |
"Maximum number of characters to use. " | |
"If a non-positive number is provided, all sentences are used.") | |
flags.DEFINE_string("model_type", "bpe", | |
"Model algorithm: unigram, bpe, word or char.") | |
flags.DEFINE_float("character_coverage", 0.9995, | |
"Character coverage to determine the minimum symbols") | |
flags.DEFINE_list( | |
"data_keys", ["en", "de"], | |
"Comma-separated list of keys to use for training the vocabulary.") | |


def dump_chars_to_textfile(dataset: tf.data.Dataset,
                           data_keys: Tuple[str, ...],
                           max_char: int = -1):
  """Writes part of a TFDS sentence dataset to lines in a text file.

  Args:
    dataset: tf.data.Dataset containing string-valued features.
    data_keys: keys of the features to dump, one sentence per line.
    max_char: maximum number of characters to dump to the text file; if
      non-positive, the whole dataset is dumped.

  Returns:
    The name of the temporary file holding the dumped dataset bytes.
  """
  ds_iter = dataset.as_numpy_iterator()
  with tempfile.NamedTemporaryFile(delete=False) as outfp:
    char_count = 0
    while True:
      example = next(ds_iter, None)
      # Stop at the end of the dataset, or once the character budget is spent.
      if example is None or (
          max_char > 0 and char_count > max_char):
        break
      for k in data_keys:
        line = example[k] + b"\n"
        char_count += len(line)
        outfp.write(line)
  return outfp.name
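
# A minimal usage sketch of dump_chars_to_textfile on a fabricated in-memory
# dataset (the values below are illustrative, not taken from this script):
#
#   _ds = tf.data.Dataset.from_tensor_slices(
#       {"en": [b"hello world"], "de": [b"hallo welt"]})
#   _path = dump_chars_to_textfile(_ds, ("en", "de"))
#   # _path now names a temp file containing b"hello world\nhallo welt\n";
#   # the caller is responsible for os.remove(_path).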


def train_sentencepiece(
    file_path: str,
    model_path: str,
    vocab_size: int,
    character_coverage: float,
    model_type: str):
  """Trains a SentencePiece tokenizer from a subset of a tf dataset.

  Args:
    file_path: path of the text file to train sentencepiece on.
    model_path: path prefix under which the trained model is saved; the
      trained model is written to "{model_path}.model".
    vocab_size: size of the vocabulary to train.
    character_coverage: amount of characters covered by the model; good
      defaults are 0.9995 for languages with a rich character set like
      Japanese or Chinese, and 1.0 for other languages with a small
      character set.
    model_type: type of sentencepiece vocab to train.
  """
  argstr = " ".join([
      f"--input={file_path}", f"--vocab_size={vocab_size}",
      f"--character_coverage={character_coverage}",
      f"--model_prefix={model_path}", f"--model_type={model_type}",
      # Reserved ids match the module docstring: pad=0, eos=1, unk=2;
      # bos_id=-1 disables bos.
      "--bos_id=-1", "--pad_id=0", "--eos_id=1", "--unk_id=2"
  ])
  SentencePieceTrainer.Train(argstr)
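
# A minimal sketch (not part of the training flow) of loading the trained
# model to confirm the reserved-token layout from the module docstring; the
# model path below is hypothetical:
#
#   import sentencepiece as spm
#   _sp = spm.SentencePieceProcessor(model_file="/tmp/wmt14_spm.model")
#   assert _sp.pad_id() == 0 and _sp.eos_id() == 1 and _sp.unk_id() == 2
#   assert _sp.bos_id() == -1  # bos is not reserved.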


def main(argv: List[str]):
  del argv  # Unused.
  builder = tfds.builder(FLAGS.tfds_name, data_dir=FLAGS.tfds_dir)
  ds = builder.as_dataset(split=FLAGS.tfds_split)
  tmp_filename = dump_chars_to_textfile(ds, FLAGS.data_keys, FLAGS.max_char)
  logging.info("Sentencepiece model will be placed here: %s",
               FLAGS.output_model_path)
  train_sentencepiece(tmp_filename,
                      FLAGS.output_model_path,
                      FLAGS.vocab_size,
                      FLAGS.character_coverage,
                      FLAGS.model_type)
  # Clean up the temporary training file once the model has been trained.
  os.remove(tmp_filename)


if __name__ == "__main__":
  app.run(main)