# gss / pile.py
# naxautify's picture
# init
# e6333f5
import codecs
import json
import random
import time
from typing import Literal

import requests
import zstandard as zstd
from torch.utils.data import IterableDataset, get_worker_info
# Names of the dataset splits that can be requested.
Subset = Literal["train", "val", "test"]

# Download locations of the zstd-compressed JSON-lines archives, keyed by
# split. The training split is spread over 30 shards numbered 00 through 29.
URLs = {
    "val": ["https://the-eye.eu/public/AI/pile/val.jsonl.zst"],
    "test": ["https://the-eye.eu/public/AI/pile/test.jsonl.zst"],
    "train": [
        f"https://the-eye.eu/public/AI/pile/train/{shard:02d}.jsonl.zst"
        for shard in range(30)
    ],
}
def _read_line_from_stream(reader, initial_line="", buffer_size=4096):
    """Return the next line of text from a byte stream.

    Args:
        reader: Object whose ``read(n)`` returns up to ``n`` bytes of UTF-8
            encoded data and an empty value once the stream is exhausted.
        initial_line: Text left over from the previous call; it is placed in
            front of whatever is read now.
        buffer_size: Number of bytes requested from ``reader`` per read.

    Returns:
        A two-element list ``[line, rest]``: ``line`` is the text before the
        first newline (the newline itself is dropped) and ``rest`` is all
        remaining buffered text, to be passed back as ``initial_line``.

    Raises:
        StopIteration: The stream is exhausted and no text is left over.
        UnicodeDecodeError: The data is not valid UTF-8, or the stream ends
            in the middle of a multi-byte character.
    """
    # A complete line may already be buffered from an earlier read. Serve it
    # without touching the stream; otherwise the carry-over text grows with
    # every call and is lost once the stream reports end-of-data.
    if "\n" in initial_line:
        return initial_line.split("\n", 1)

    # Read boundaries can fall inside a multi-byte character, so chunks must
    # be decoded incrementally rather than one by one.
    decoder = codecs.getincrementaldecoder("utf-8")()
    pieces = [initial_line]
    found_newline = False
    while True:
        chunk = reader.read(buffer_size)
        if not chunk:
            # End of stream: final=True raises if a character was cut short.
            pieces.append(decoder.decode(b"", final=True))
            break
        text = decoder.decode(chunk)
        pieces.append(text)
        found_newline = found_newline or "\n" in text
        # Only stop on a character boundary: bytes still held by the decoder
        # cannot be handed back to the caller, which only carries text.
        if found_newline and not decoder.getstate()[0]:
            break

    line = "".join(pieces)
    if found_newline:
        return line.split("\n", 1)
    if not line:
        raise StopIteration
    # Final line of a stream that does not end with a newline.
    return [line, ""]
def _line_streamer(reader, buffer_size=4096):
    """Yield lines of text read from ``reader``, one at a time.

    Each line is obtained from ``_read_line_from_stream``; the text it hands
    back as unconsumed is fed into the following call. The generator ends as
    soon as that helper raises ``StopIteration``.

    Args:
        reader: Byte stream forwarded to ``_read_line_from_stream``.
        buffer_size: Read size forwarded to ``_read_line_from_stream``.
    """
    leftover = ""
    try:
        while True:
            current, leftover = _read_line_from_stream(reader, leftover, buffer_size)
            yield current
    except StopIteration:
        return
class ThePile(IterableDataset):
    """Endless stream of JSON records from The Pile, downloaded on the fly.

    Every pass shuffles the archive URLs of the chosen split, then streams
    each archive over HTTP, decompresses it and yields one parsed JSON object
    per line. Iteration never finishes by itself; the consumer decides when
    to stop.

    NOTE: archives are not partitioned between DataLoader workers. Each
    worker streams every archive of the split, in an order derived from its
    worker id.
    """

    # Number of decompressed bytes requested per read while splitting lines.
    TEXT_BUFFER_SIZE = 4096
    # (connect, read) timeouts in seconds, so a stalled server raises instead
    # of blocking the worker forever.
    REQUEST_TIMEOUT = (10, 60)

    def __init__(self, subset: Subset):
        """Remember which split ("train", "val" or "test") to stream."""
        super().__init__()
        self.subset = subset

    def __iter__(self):
        """Yield parsed JSON records forever.

        Raises:
            KeyError: ``self.subset`` is not a key of ``URLs``.
            requests.HTTPError: The server answered with an error status.
        """
        urls = URLs[self.subset].copy()

        # Inside a DataLoader worker the worker id seeds the shuffle, so each
        # worker gets its own reproducible order. In the main process the
        # generator is seeded from system entropy. The generator is built
        # once, so consecutive passes draw fresh permutations.
        wi = get_worker_info()
        seed = wi.id if wi is not None else None
        rnd = random.Random(seed)

        while True:
            rnd.shuffle(urls)
            for url in urls:
                # The context manager releases the connection even when the
                # consumer abandons the iterator in the middle of an archive.
                with requests.get(
                    url, stream=True, timeout=self.REQUEST_TIMEOUT
                ) as response:
                    # Fail loudly on 4xx/5xx rather than handing an error
                    # page to the decompressor.
                    response.raise_for_status()
                    with zstd.ZstdDecompressor().stream_reader(response.raw) as reader:
                        for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                            if not line.strip():
                                # Blank lines carry no record.
                                continue
                            yield json.loads(line)
if __name__ == "__main__":
    # Throughput probe: stream the training split and let tqdm report the
    # iteration rate. The records themselves are discarded.
    from tqdm import tqdm

    dataset = ThePile("train")
    progress = tqdm(dataset, smoothing=0.01)
    for data in progress:
        continue
    # Average: ~2000 samples/sec/worker