#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a Unigram tokenizer."""
import argparse
import logging

import datasets
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

from transformers import AlbertTokenizerFast


logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Train a unigram tokenizer on the wikitext dataset.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="wikitext",
        help="Name of the training dataset. Explore datasets at: hf.co/datasets.",
    )
    parser.add_argument(
        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1000,
        help="Batch size during training.",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=10048,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        help="Limit the number of training examples (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Whether to push the trained tokenizer to the Hugging Face Hub.",
    )
    args = parser.parse_args()
    return args


def main(args):
    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")

    if args.limit is not None:
        max_train_samples = min(len(dataset), args.limit)
        dataset = dataset.select(range(max_train_samples))
        logger.info(f"Limiting the dataset to {args.limit} entries.")
    def batch_iterator():
        for i in range(0, len(dataset), args.batch_size):
            yield dataset[i : i + args.batch_size]["text"]

    # Prepare the tokenizer.
    tokenizer = Tokenizer(Unigram())
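    # Map the PTB-style quote markers `` and '' to a plain ASCII double quote before
    # any pre-tokenization happens.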
    tokenizer.normalizer = normalizers.Sequence([normalizers.Replace("``", '"'), normalizers.Replace("''", '"')])
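    # Metaspace splits on whitespace and replaces it with the ▁ (U+2581) marker,
    # following the SentencePiece convention.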
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()

    # Prepare the trainer.
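    # The special tokens follow ALBERT's convention so the trained tokenizer can be
    # wrapped in AlbertTokenizerFast below without any remapping.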
    trainer = UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=args.vocab_size,
    )

    logger.info("Training the tokenizer.")
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
    logger.info("Tokenizer training complete!")

    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
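
    # Wrap every encoded sequence as "[CLS] A [SEP]" (and "[CLS] A [SEP] B [SEP]" for pairs),
    # with token type id 0 for the first segment and 1 for the second.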
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
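    # The Metaspace decoder mirrors the pre-tokenizer, turning ▁ markers back into spaces
    # when decoding token ids to text.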
    tokenizer.decoder = decoders.Metaspace()

    if args.export_to_hub:
        logger.info("Exporting the trained tokenizer to the Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")


if __name__ == "__main__":
    args = parse_args()
    main(args)
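

# A minimal sketch of reusing the exported tokenizer (assumes --export_to_hub was passed;
# "<username>" stands in for your Hub namespace, and the repo name matches the
# push_to_hub call above):
#
#   from transformers import AlbertTokenizerFast
#
#   tokenizer = AlbertTokenizerFast.from_pretrained("<username>/unigram-tokenizer-dataset")
#   print(tokenizer("Training a Unigram tokenizer.").input_ids)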