|
import json |
|
import time |
|
import random |
|
from typing import Literal |
|
|
|
import requests |
|
import zstandard as zstd |
|
from torch.utils.data import IterableDataset, get_worker_info |
|
|
|
|
|
# Names of the Pile splits accepted by ThePile; each is a key of URLs below.
Subset = Literal["train", "val", "test"]

# Download locations of the zstd-compressed JSONL shards for each split.
# The validation and test splits are single files; the training split is
# 30 shards named 00 .. 29.
# NOTE(review): confirm these the-eye.eu mirror URLs are still reachable.
URLs = {
    "val": ["https://the-eye.eu/public/AI/pile/val.jsonl.zst"],
    "test": ["https://the-eye.eu/public/AI/pile/test.jsonl.zst"],
    "train": [
        f"https://the-eye.eu/public/AI/pile/train/{shard:02d}.jsonl.zst"
        for shard in range(30)
    ],
}
|
|
|
|
|
def _read_line_from_stream(reader, initial_line="", buffer_size=4096): |
|
line = initial_line |
|
while True: |
|
c = reader.read(buffer_size) |
|
if not c: |
|
raise StopIteration |
|
line += c.decode("utf-8") |
|
if "\n" in line: |
|
break |
|
return line.split("\n", 1) |
|
|
|
|
|
def _line_streamer(reader, buffer_size=4096):
    """Lazily yield lines read from the byte stream *reader*.

    Repeatedly asks ``_read_line_from_stream`` for the next line, carrying
    the unconsumed remainder from one call into the next, and stops as soon
    as that helper raises ``StopIteration``.

    Args:
        reader: object with a ``read(n)`` method returning bytes.
        buffer_size: number of bytes requested per ``read`` call.
    """
    leftover = ""
    try:
        while True:
            current, leftover = _read_line_from_stream(reader, leftover, buffer_size)
            yield current
    except StopIteration:
        # The helper signals exhaustion with StopIteration; it has to be
        # caught here because letting it escape a generator frame is turned
        # into a RuntimeError (PEP 479).
        return
|
|
|
|
|
class ThePile(IterableDataset):
    """Endless streaming dataset over one split of The Pile.

    Iterating downloads the split's ``.jsonl.zst`` files over HTTP,
    decompresses them on the fly and yields one ``json.loads``-parsed record
    per non-blank line. The file order is reshuffled before every pass and
    the stream never ends on its own.

    Under a multi-process ``DataLoader`` the work is divided between workers
    (see ``_shard_urls``) so that workers do not yield duplicate records.

    NOTE(review): requests are made without a timeout or retry, so a stalled
    connection blocks the worker indefinitely and any network error ends the
    stream.
    """

    # Bytes requested from the decompressed stream per read.
    TEXT_BUFFER_SIZE = 4096

    def __init__(self, subset: Subset):
        # Split name; used as the key into the module-level ``URLs`` mapping.
        self.subset = subset

    @staticmethod
    def _shard_urls(urls, worker_id, num_workers):
        """Decide which part of *urls* worker *worker_id* should read.

        Returns ``(worker_urls, stride, offset)``. When there are at least as
        many files as workers, files are dealt out round-robin and the worker
        keeps every record (``stride=1, offset=0``). Otherwise every worker
        reads every file but keeps only lines whose 0-based index ``i``
        satisfies ``i % stride == offset``. ``worker_urls`` is always a new
        list, so shuffling it never mutates *urls*.
        """
        if len(urls) >= num_workers:
            return list(urls[worker_id::num_workers]), 1, 0
        return list(urls), num_workers, worker_id

    def _iter_url(self, url, stride=1, offset=0):
        """Yield parsed JSON records from one ``.jsonl.zst`` URL.

        Only lines whose 0-based index ``i`` satisfies ``i % stride == offset``
        are kept; whitespace-only lines are skipped.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        # ``with`` releases the connection even if the consumer stops
        # iterating mid-file (closing the generator runs the exit handlers).
        with requests.get(url, stream=True) as response:
            # Fail with a clear HTTP error instead of handing an error page
            # to the zstd decoder.
            response.raise_for_status()
            with zstd.ZstdDecompressor().stream_reader(response.raw) as reader:
                lines = _line_streamer(reader, self.TEXT_BUFFER_SIZE)
                for index, line in enumerate(lines):
                    if index % stride != offset or not line.strip():
                        continue
                    yield json.loads(line)

    def __iter__(self):
        """Yield records from this worker's share of the split, forever."""
        wi = get_worker_info()
        worker_id = wi.id if wi is not None else 0
        num_workers = wi.num_workers if wi is not None else 1
        urls, stride, offset = self._shard_urls(
            URLs[self.subset], worker_id, num_workers
        )
        if not urls:
            # Nothing to read; avoid spinning forever in the loop below.
            return

        # Workers seed with their id (reproducible order); the main process
        # seeds from OS entropy. The RNG is created once so successive passes
        # draw fresh permutations; re-seeding every pass would just reapply
        # one fixed permutation, making the order cycle with a short period.
        rnd = random.Random(wi.id if wi is not None else None)
        while True:
            rnd.shuffle(urls)
            for url in urls:
                yield from self._iter_url(url, stride, offset)
|
|
|
|
|
if __name__ == "__main__":
    # Manual throughput check: stream the training split and let tqdm report
    # records per second. The dataset is endless, so stop it with Ctrl-C.
    # A low smoothing factor keeps the displayed rate near the long-run mean.
    from tqdm import tqdm

    for _record in tqdm(ThePile("train"), smoothing=0.01):
        pass
|
|