#!/usr/bin/env python3 | |
import os

from datasets import load_dataset
from datasets import load_from_disk
from tokenizers import ByteLevelBPETokenizer, SentencePieceBPETokenizer
from tqdm import tqdm

from utils import keep_devnagri
# Stream the Hindi ("hi") train split of mC4 instead of downloading it whole;
# with streaming=True, records are pulled lazily as batch_iterator() walks them.
dataset = load_dataset(path="mc4", name="hi", split="train", streaming=True)
# Untrained SentencePiece-style BPE tokenizer; it is trained further below.
tokenizer = SentencePieceBPETokenizer(add_prefix_space=True)
def batch_iterator(batch_size=100_000, examples=None):
    """Yield lists of texts from ``examples`` for tokenizer training.

    Each record's ``"text"`` field is passed to ``keep_devnagri``, which
    returns a ``(text, is_just_punctuation)`` pair (presumably the text with
    only Devanagari content kept -- see ``utils.keep_devnagri``). Records
    flagged as just punctuation are skipped.

    Args:
        batch_size: Number of kept texts per yielded list. Must be >= 1.
        examples: Iterable of mapping-like records with a ``"text"`` key.
            Defaults to the module-level ``dataset``.

    Yields:
        Lists of exactly ``batch_size`` strings, then one final shorter list
        if any texts remain. Nothing is yielded for an empty input.

    Raises:
        ValueError: If ``batch_size`` < 1 (raised on first iteration, since
            this is a generator).
    """
    # A non-positive size would never match the length check below, so the
    # whole streamed corpus would be buffered in memory as a single batch.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if examples is None:
        examples = dataset
    # total docs in the default dataset: 1,85,07,273
    batch = []
    for example in examples:
        devnagari_text, is_just_punctuation = keep_devnagri(example["text"])
        if is_just_punctuation:
            continue
        batch.append(devnagari_text)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Flush the trailing partial batch, if any.
    if batch:
        yield batch
# Resolve the output path up front. It can be overridden with the
# TOKENIZER_OUTPUT environment variable; the default is the original path.
output_path = os.environ.get(
    "TOKENIZER_OUTPUT", "/home/khandelia1000/tokenizer.json"
)
# Create the destination directory *before* training: training over the full
# streamed corpus is expensive, and hitting a missing directory only at save
# time would throw that work away.
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

# Customized training
tokenizer.train_from_iterator(
    batch_iterator(),
    vocab_size=50265,
    min_frequency=50,
    special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<mask>",
    ],
)
# Save the trained tokenizer as a single JSON file.
tokenizer.save(output_path)