import torch
from torch.utils.data import IterableDataset

from transformers import PreTrainedTokenizerBase

from pile import ThePile


class ThePileTokenized(IterableDataset):
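    """Wraps a ThePile dataset and yields fixed-length token tensors.

    Documents are tokenized, joined with EOS separators into a running
    buffer, and emitted as chunks of ``max_length`` tokens. With
    ``repeat_factor > 1`` the buffer advances by only
    ``max_length / repeat_factor`` tokens per chunk, so consecutive
    samples overlap and each token is seen roughly ``repeat_factor``
    times.
    """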
    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
        repeat_factor: float = 1.0,
    ):
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repeat_factor = repeat_factor

    def __iter__(self):
        ds = iter(self.pile)
        buffer = []
        while True:
            # Tokenize the next document and append it to the buffer,
            # using the EOS token as a document separator.
            tokens = self.tokenizer.encode(next(ds)["text"])
            buffer += [self.tokenizer.eos_token_id] + tokens
            # Emit fixed-length chunks; advance the buffer by
            # max_length / repeat_factor so samples overlap when
            # repeat_factor > 1.
            while len(buffer) > self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                buffer = buffer[int(self.max_length / self.repeat_factor) :]


if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=2048,
        repeat_factor=4 / 3,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
    )
    # Consume batches to benchmark throughput.
    for batch in tqdm(dataloader, smoothing=0.01):
        pass
    # ~6 iters/s for 1 worker
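    # Note: with num_workers > 1, each DataLoader worker runs its own copy of
    # __iter__ over the full stream, so samples would be duplicated unless
    # worker-aware sharding (e.g. via torch.utils.data.get_worker_info()) is added.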