import json
import random
from typing import Literal

import requests
import zstandard as zstd
from torch.utils.data import IterableDataset, get_worker_info


Subset = Literal["train", "val", "test"]
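# Direct links to The Pile's zstd-compressed jsonl shards hosted on the-eye.eu;
# the training split spans 30 shard files, while val and test are single files.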
URLs = {
    "val": [
        "https://the-eye.eu/public/AI/pile/val.jsonl.zst",
    ],
    "test": [
        "https://the-eye.eu/public/AI/pile/test.jsonl.zst",
    ],
    "train": [
        "https://the-eye.eu/public/AI/pile/train/00.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/01.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/02.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/03.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/04.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/05.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/06.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/07.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/08.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/09.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/10.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/11.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/12.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/13.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/14.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/15.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/16.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/17.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/18.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/19.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/20.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/21.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/22.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/23.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/24.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/25.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/26.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/27.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/28.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/29.jsonl.zst",
    ],
}


def _read_line_from_stream(reader, initial_line=b"", buffer_size=4096):
    # Accumulate raw bytes until a newline arrives. Working on bytes (instead
    # of decoding each chunk) avoids UnicodeDecodeError when a multi-byte
    # UTF-8 character straddles a chunk boundary.
    line = initial_line
    while True:
        chunk = reader.read(buffer_size)
        if not chunk:
            raise StopIteration
        line += chunk
        if b"\n" in line:
            break
    return line.split(b"\n", 1)


def _line_streamer(reader, buffer_size=4096):
    # Yield decoded lines from a binary stream; the tail of each chunk is kept
    # and prepended to the next line.
    rest = b""
    while True:
        try:
            line, rest = _read_line_from_stream(
                reader,
                rest,
                buffer_size,
            )
            yield line.decode("utf-8")
        except StopIteration:
            break


class ThePile(IterableDataset):
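    """Stream The Pile over HTTP and yield decoded JSON records.

    Iteration never terminates on its own: shard URLs are shuffled and
    replayed indefinitely, so the consumer decides when to stop.
    """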
    TEXT_BUFFER_SIZE = 4096

    def __init__(self, subset: Subset):
        self.subset = subset

    def __iter__(self):
        # Seed each DataLoader worker with its worker id so that workers walk
        # the shard URLs in different orders; the main process falls back to a
        # system-random seed.
        wi = get_worker_info()
        seed = wi.id if wi is not None else None
        rnd = random.Random(seed)
        urls = URLs[self.subset].copy()
        while True:
            rnd.shuffle(urls)
            for url in urls:
                # Stream the compressed shard over HTTP and decompress it on
                # the fly, yielding one JSON record per line.
                with requests.get(url, stream=True) as r:
                    r.raise_for_status()
                    with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
                        for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                            yield json.loads(line)


if __name__ == "__main__":
    from tqdm import tqdm

    dataset = ThePile("train")
    for data in tqdm(dataset, smoothing=0.01):
        pass
    # Average: ~2000 samples/sec/worker
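
# A minimal usage sketch (not part of the original file) for consuming the
# dataset with a multi-worker DataLoader; batch_size=None passes samples
# through unbatched, and each worker shuffles the shard URLs independently.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(ThePile("train"), batch_size=None, num_workers=4)
#     for sample in loader:
#         text = sample["text"]  # Pile records carry the document under "text"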