# Snapshot of revision e6333f5 (file size: 3,811 bytes).
import codecs
import json
import random
import time
from typing import Literal

import requests
import zstandard as zstd
from torch.utils.data import IterableDataset, get_worker_info
# Names of the dataset splits a caller may request.
Subset = Literal["train", "val", "test"]

# Download locations of the zstd-compressed JSONL shards, keyed by split name.
# The train split is published as 30 shards numbered 00 through 29.
URLs = {
    "val": ["https://the-eye.eu/public/AI/pile/val.jsonl.zst"],
    "test": ["https://the-eye.eu/public/AI/pile/test.jsonl.zst"],
    "train": [
        f"https://the-eye.eu/public/AI/pile/train/{shard:02d}.jsonl.zst"
        for shard in range(30)
    ],
}
def _read_line_from_stream(reader, initial_line="", buffer_size=4096):
    """Return the next line of a UTF-8 byte stream and the text buffered after it.

    Args:
        reader: Binary stream whose ``read(n)`` returns at most ``n`` bytes and
            an empty bytes object once the stream is exhausted.
        initial_line: Already-decoded text left over from the previous call.
        buffer_size: Number of bytes requested per ``read`` call.

    Returns:
        A two-item list ``[line, rest]``. ``line`` is the text before the first
        newline; ``rest`` is everything already buffered after that newline and
        must be passed back as ``initial_line`` on the next call. A final line
        that is not newline-terminated is returned with ``rest == ""``.

    Raises:
        StopIteration: The stream is exhausted and no buffered text remains.
        UnicodeDecodeError: The stream contains invalid UTF-8 or ends in the
            middle of a multi-byte character.
    """
    line = initial_line
    # An incremental decoder holds back the bytes of a character that is cut
    # in half by a chunk boundary instead of failing on them.
    decoder = codecs.getincrementaldecoder("utf-8")()

    # Only touch the stream when the buffered text has no complete line yet;
    # otherwise lines that are already buffered would be lost at end of stream.
    while "\n" not in line:
        chunk = reader.read(buffer_size)
        if not chunk:
            # final=True makes a dangling partial character raise here.
            line += decoder.decode(b"", final=True)
            if line:
                # Last line of a stream that does not end with a newline.
                return [line, ""]
            raise StopIteration
        line += decoder.decode(chunk)

    # The decoder is discarded on return, so any half-read character must be
    # completed now (at most a few single-byte reads) or its bytes are lost.
    while decoder.getstate()[0]:
        extra = reader.read(1)
        if not extra:
            line += decoder.decode(b"", final=True)
            break
        line += decoder.decode(extra)

    return line.split("\n", 1)
def _line_streamer(reader, buffer_size=4096):
    """Generate the lines of ``reader`` one at a time.

    Each step hands the text left over from the previous step back to
    ``_read_line_from_stream`` together with ``reader`` and ``buffer_size``,
    and yields the line it returns.

    Args:
        reader: Stream object forwarded unchanged to ``_read_line_from_stream``.
        buffer_size: Chunk size forwarded to ``_read_line_from_stream``.

    Yields:
        One line of text per iteration, without its trailing newline.
    """
    leftover = ""
    try:
        while True:
            current, leftover = _read_line_from_stream(
                reader, leftover, buffer_size
            )
            yield current
    except StopIteration:
        # The helper signals exhaustion with StopIteration; finish the
        # generator normally rather than letting it escape.
        return
class ThePile(IterableDataset):
    """Endless stream of JSON records downloaded from the shards of one split.

    Iteration never terminates: once every shard URL of the split has been
    consumed, the URL list is reshuffled and streamed again. Each DataLoader
    worker walks over *all* shards of the split; only the visiting order
    depends on the worker id.
    """

    # Number of bytes requested per read from the decompressed stream.
    TEXT_BUFFER_SIZE = 4096

    def __init__(self, subset: Subset):
        """Remember which split to stream.

        Args:
            subset: Key into ``URLs`` selecting the shard list to download.
        """
        super().__init__()
        self.subset = subset

    def __iter__(self):
        """Yield one decoded JSON object per non-blank line, shard after shard.

        Yields:
            The result of ``json.loads`` for each line of each shard.

        Raises:
            KeyError: ``self.subset`` is not a key of ``URLs``.
            requests.HTTPError: A shard request returned an error status.
        """
        # Work on a copy so shuffling never reorders the module-level list.
        urls = URLs[self.subset].copy()
        while True:
            wi = get_worker_info()
            # Inside a DataLoader worker the worker id is the seed; in the
            # main process (no worker info) the generator seeds itself.
            seed = wi.id if wi is not None else None
            rnd = random.Random(seed)
            rnd.shuffle(urls)
            for url in urls:
                # The context manager closes the connection when the shard is
                # finished or the consumer abandons the generator.
                with requests.get(url, stream=True) as response:
                    # Fail loudly on 4xx/5xx instead of handing an error page
                    # to the zstd decoder.
                    response.raise_for_status()
                    decompressor = zstd.ZstdDecompressor()
                    with decompressor.stream_reader(response.raw) as reader:
                        for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                            if not line.strip():
                                # Blank lines carry no record and would make
                                # json.loads raise.
                                continue
                            yield json.loads(line)
if __name__ == "__main__":
    # Throughput check: drain the train stream and let tqdm report the
    # iteration rate. The stream is endless, so this runs until interrupted.
    from tqdm import tqdm

    progress = tqdm(ThePile("train"), smoothing=0.01)
    for _ in progress:
        pass
# Average: ~2000 samples/sec/worker