"""
Download, preprocess and serve the TinyStories dataset as a DataLoader.
"""

import argparse
import glob
import json
import os
import random
from concurrent.futures import ProcessPoolExecutor
from functools import partial

import numpy as np
import requests
import sentencepiece as spm
import torch
import torch.distributed as dist
from tqdm import tqdm

from tokenizer import Tokenizer

DATA_CACHE_DIR = "data"


def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    resp.raise_for_status()  # fail loudly instead of saving an error page to disk
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)


def download():
    """Downloads the TinyStories dataset to DATA_CACHE_DIR"""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
    data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
    data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    # unpack the tar.gz file into all the data shards (json files)
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)
        print(f"Unpacking {data_filename}...")
        os.system(f"tar -xzf {data_filename} -C {data_dir}")
    else:
        print(f"{data_dir} already exists, skipping unpacking...")

    # print a single example just for debugging and such
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    with open(shard_filenames[0], "r") as f:
        data = json.load(f)
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
    print(f"Example story:\n{data[0]}")


def train_vocab(vocab_size):
    """
    Trains a custom sentencepiece tokenizer on the TinyStories dataset.
    The tokenizer model is saved as DATA_CACHE_DIR/tok{N}.model, where N is the
    vocab size. The pretokenized .bin files will later go into a
    DATA_CACHE_DIR/tok{N} directory.
    """
    assert vocab_size > 0, "Vocab size must be positive"

    # output file prefix path for sentencepiece
    prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")

    # how many shards to use for vocab training; a subset keeps training fast
    num_shards = 10

    # 1) export a large chunk of text as a single text file tiny.txt
    tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
    with open(tiny_file, "w", encoding="utf-8") as of:
        for shard in tqdm(shard_filenames[:num_shards]):
            with open(shard, "r") as f:
                data = json.load(f)
            for example in data:
                text = example["story"]
                text = text.strip()
                of.write(text + "\n")
    print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")

    # 2) train the sentencepiece model
    print("Will now train the vocab...")
    spm.SentencePieceTrainer.train(
        input=tiny_file,
        model_prefix=prefix,
        model_type="bpe",
        vocab_size=vocab_size,
        self_test_sample_size=0,
        input_format="text",
        character_coverage=1.0,
        num_threads=os.cpu_count(),
        split_digits=True,
        allow_whitespace_only_pieces=True,
        byte_fallback=True,
        unk_surface=r" \342\201\207 ",
        normalization_rule_name="identity",
    )

    # 3) optional cleanup: ask the user if they'd like to delete tiny.txt
    dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
    if dec.lower() == "y":
        os.remove(tiny_file)
        print(f"Deleted {tiny_file}")

    print(f"Trained tokenizer is in {prefix}.model")
    print("Done.")


def process_shard(args, vocab_size):
    """Tokenizes one .json shard of TinyStories and writes it as a .bin file."""
    shard_id, shard = args
    tokenizer_model = get_tokenizer_model_path(vocab_size)
    enc = Tokenizer(tokenizer_model)
    with open(shard, "r") as f:
        data = json.load(f)
    all_tokens = []
    for example in tqdm(data, position=shard_id):
        text = example["story"]
        text = text.strip()  # get rid of leading/trailing whitespace
        tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
        all_tokens.extend(tokens)
    # convert to a uint16 numpy array
    all_tokens = np.array(all_tokens, dtype=np.uint16)
    # calculate the output filename
    if vocab_size == 0:
        # if we're using the Llama 2 tokenizer, save the .bin file in the same dir
        tokenized_filename = shard.replace(".json", ".bin")
    else:
        # save .bin files into a new tok{N} directory
        bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
        shard_basename = os.path.basename(shard)
        bin_basename = shard_basename.replace(".json", ".bin")
        tokenized_filename = os.path.join(bin_dir, bin_basename)
    # write the bytes
    with open(tokenized_filename, "wb") as f:
        f.write(all_tokens.tobytes())
    # calculate the average sequence length (stories are separated by BOS=1)
    avg_seq_len = all_tokens.size / ((all_tokens == 1).sum())
    print(f"Saved {tokenized_filename}, average seqlen: {avg_seq_len:.2f}")


def pretokenize(vocab_size):
    # iterate over all the shards and tokenize each one in a process pool
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if vocab_size > 0:
        # .bin files will be saved into the tok{N} dir, create it once here
        bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
        os.makedirs(bin_dir, exist_ok=True)

    # wrap map() in list() so that exceptions raised in worker processes
    # propagate here instead of being silently dropped
    fun = partial(process_shard, vocab_size=vocab_size)
    with ProcessPoolExecutor() as executor:
        list(executor.map(fun, enumerate(shard_filenames)))
    print("Done.")


class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors."""

    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.vocab_source = vocab_source

    def __iter__(self):
        # get worker info within a DataLoader
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # combine the worker_id and rank into a unique seed for the rng,
        # so that every worker in every process streams different batches
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")
        if self.vocab_source == "llama2":
            # the .bin files sit right alongside the .json files
            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        elif self.vocab_source == "custom":
            # the .bin files are in the tok{N} directory
            bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        else:
            raise ValueError(f"Unknown vocab_source {self.vocab_source}")
        # train/test split: use only shard 0 for test, the rest for train
        shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
        assert len(shard_filenames) > 0, f"No bin files found in {bin_dir}"
        while True:
            rng.shuffle(shard_filenames)
            for shard in shard_filenames:
                # open the shard for reading, but keep it on disk with memmap
                m = np.memmap(shard, dtype=np.uint16, mode="r")
                num_batches = len(m) // self.max_seq_len
                num_batches -= 1  # drop the last partial batch
                assert num_batches > 0, "this shard is way too small? investigate."
                ixs = list(range(num_batches))
                rng.shuffle(ixs)
                for ix in ixs:
                    start = ix * self.max_seq_len
                    end = start + self.max_seq_len + 1
                    # calling .astype copies the data into a new numpy array, now in RAM
                    chunk = torch.from_numpy((m[start:end]).astype(np.int64))
                    x = chunk[:-1]  # inputs
                    y = chunk[1:]  # targets, shifted by one position
                    yield x, y
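

# Minimal smoke test for PretokDataset outside of a DataLoader (a sketch;
# the function and the hyperparameter values are illustrative assumptions,
# and it requires pretokenized .bin files to exist on disk):
def _smoke_test_pretok_dataset():
    ds = PretokDataset(split="train", max_seq_len=256, vocab_size=0, vocab_source="llama2")
    x, y = next(iter(ds))
    # y is x shifted left by one: y[t] is the prediction target after x[:t+1]
    assert x.shape == y.shape == (256,)
    assert torch.equal(x[1:], y[:-1])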


def get_tokenizer_model_path(vocab_size):
    """
    Returns the path to the sentencepiece tokenizer model for a given vocab size.
    vocab_size = 0 designates the default Llama 2 tokenizer; in that case
    None is returned.
    """
    if vocab_size == 0:
        return None
    else:
        return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")


class Task:
    """Thin wrapper that turns PretokDataset into an infinite batch iterator."""

    @staticmethod
    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
        ds = PretokDataset(**dataset_kwargs)
        dl = torch.utils.data.DataLoader(
            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
        )
        for x, y in dl:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            yield x, y
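

# Example of how a training loop might draw batches (a sketch; the batch size,
# device string, and hyperparameters are assumptions for illustration):
def _example_iter_batches():
    batch_iter = Task.iter_batches(
        batch_size=32,
        device="cpu",
        num_workers=0,
        split="train",
        max_seq_len=256,
        vocab_size=0,
        vocab_source="llama2",
    )
    x, y = next(batch_iter)
    print(x.shape, y.shape)  # both torch.Size([32, 256])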


if __name__ == "__main__":
    """
    These stages are designed to be run in order.

    To tokenize data with the Llama 2 tokenizer:
    python tinystories.py download
    python tinystories.py pretokenize

    To tokenize data with a custom tokenizer trained with sentencepiece, e.g.:
    python tinystories.py download
    python tinystories.py train_vocab --vocab_size=2048
    python tinystories.py pretokenize --vocab_size=2048
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("stage", type=str, choices=["download", "pretokenize", "train_vocab"])
    parser.add_argument("--vocab_size", type=int, default=0, help="pretokenization vocab size. 0 = use Llama 2 tokenizer.")
    args = parser.parse_args()

    # depending on the stage, call the appropriate function
    if args.stage == "download":
        download()
    elif args.stage == "train_vocab":
        train_vocab(vocab_size=args.vocab_size)
    elif args.stage == "pretokenize":
        pretokenize(vocab_size=args.vocab_size)
    else:
        raise ValueError(f"Unknown stage {args.stage}")