# Stream the C4 dataset from the Hugging Face Hub, tokenized with the BLOOM tokenizer, for PyTorch language-model training
import torch
import random
from datasets import load_dataset
from transformers import BloomTokenizerFast
from torch.utils.data import Dataset, get_worker_info
# Repeat a finite iterable forever so the streaming dataset never runs dry.
def cycled(itr):
    while True:
        for itm in itr:
            yield itm
class C4X(Dataset):
    def __init__(self, seq_len=512, split='train'):
        self.seq = seq_len
        # Stream C4 instead of downloading it up front; note that newer
        # versions of `datasets` host this corpus under the id 'allenai/c4'.
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = BloomTokenizerFast.from_pretrained('bigscience/bloomz-1b7')
        self.init = False

    def __len__(self):
        # The stream is effectively unbounded; report a huge nominal length
        # so samplers and progress bars have something to work with.
        return 1_000_000_000

    def _init(self):
        # Deferred per-worker setup: shuffling is seeded per DataLoader
        # worker so each worker sees a differently ordered stream.
        if self.init:
            return
        wi = get_worker_info()
        # get_worker_info() returns None in the main process (num_workers=0);
        # fall back to torch's initial seed so single-process use also works.
        seed = wi.seed if wi is not None else torch.initial_seed()
        random.seed(seed)  # also seeds the random crop offsets in _get_full
        self.ds = cycled(
            self.ds.shuffle(
                seed=seed,
                buffer_size=10_000,
            )
        )
        self.init = True
    def _get_next(self):
        # Tokenize the next document pulled from the stream.
        self._init()
        obj = next(self.ds)['text']
        tkn = self.tok.encode(obj)
        return tkn

    def _get_full(self):
        # Concatenate EOS-separated documents until at least seq_len tokens
        # are available, then return a random seq_len-token window.
        obj = []
        while len(obj) < self.seq:
            obj += self._get_next()
            obj.append(self.tok.eos_token_id)
        s = random.randint(0, len(obj) - self.seq)
        return obj[s:s + self.seq]

    def __getitem__(self, _):
        # The index is ignored; every access draws a fresh sample from the stream.
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        return self.tok.decode(tkns)
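
# A minimal usage sketch, not part of the original file: wrapping C4X in a
# DataLoader. The batch_size and num_workers values here are illustrative
# assumptions; with num_workers > 0, each worker lazily seeds its own
# shuffled copy of the stream via _init().
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    ds = C4X(seq_len=512, split='train')
    loader = DataLoader(ds, batch_size=8, num_workers=2)
    batch = next(iter(loader))
    print(batch.shape)          # expected: torch.Size([8, 512])
    print(ds.decode(batch[0]))  # round-trip the first sequence back to text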