# gss/c4x.py
# Author: naxautify; commit e6333f5 ("init"); 1.53 kB.
# Stream the C4 dataset from Hugging Face, tokenized with the BLOOM tokenizer, for PyTorch language-model training.
import json
import torch
import random
from datasets import load_dataset
from transformers import BloomTokenizerFast
from torch.utils.data import Dataset, get_worker_info
def cycled(itr):
    """Yield the items of *itr* forever, restarting it after each full pass.

    Unlike ``itertools.cycle``, items are not cached: every pass re-iterates
    *itr* itself. This suits large re-iterable sources such as a streaming
    dataset, but it relies on ``iter(itr)`` starting over each time.

    If a whole pass produces no items, the generator stops instead of
    spinning forever without yielding. That covers an empty iterable and a
    one-shot iterator that is already exhausted.
    """
    while True:
        produced = False
        for itm in itr:
            produced = True
            yield itm
        if not produced:
            # Nothing left to cycle over; end rather than busy-loop.
            return
class C4X(Dataset):
    """Endless stream of fixed-length token windows drawn from C4 (English).

    Documents are pulled from the Hugging Face ``c4``/``en`` split in
    streaming mode and tokenized with the BLOOM tokenizer. They are
    concatenated with an EOS token after each document, and a random
    ``seq_len``-token window is cut from the result. The index passed to
    ``__getitem__`` is ignored: every call draws the next sample from the
    stream.

    Args:
        seq_len: Number of tokens in each returned sample.
        split: Dataset split to stream (e.g. ``'train'``).
        buffer_size: Size of the streaming shuffle buffer.
    """

    def __init__(self, seq_len=512, split='train', buffer_size=10_000):
        self.seq = seq_len
        self.buffer_size = buffer_size
        # streaming=True: documents are fetched while iterating rather than
        # downloaded up front.
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = BloomTokenizerFast.from_pretrained('bigscience/bloomz-1b7')
        # Per-process iterator state, built lazily by _init(). self.ds keeps
        # the source dataset so the iterator can be rebuilt in each worker.
        self.init = False
        self._init_key = None
        self._it = None

    def __len__(self):
        # Nominal length only: the stream never ends, but samplers and
        # DataLoader need a finite __len__.
        return 1_000_000_000

    def __getstate__(self):
        """Return picklable state without the live document generator.

        Generators cannot be pickled. Dropping the iterator lets the dataset
        be sent to spawned DataLoader workers, which rebuild it in _init().
        """
        state = self.__dict__.copy()
        state['_it'] = None
        state['init'] = False
        state['_init_key'] = None
        return state

    def _init(self):
        """Build this process's shuffled, endlessly cycling document iterator.

        The iterator is rebuilt whenever the current DataLoader worker
        differs from the one that created it. Without this, a copy inherited
        from the main process (e.g. via fork) would replay the same stream in
        every worker.
        """
        wi = get_worker_info()
        key = None if wi is None else wi.id
        if self.init and self._init_key == key:
            return
        # get_worker_info() is None in the main process (num_workers=0), so
        # fall back to torch's seed instead of crashing on wi.seed.
        seed = torch.initial_seed() if wi is None else wi.seed
        self._it = cycled(
            self.ds.shuffle(
                seed=seed,
                buffer_size=self.buffer_size,
            )
        )
        self._init_key = key
        self.init = True

    def _get_next(self):
        """Return the token ids of the next document's ``'text'`` field."""
        self._init()
        obj = next(self._it)['text']
        return self.tok.encode(obj)

    def _get_full(self):
        """Return one random window of exactly ``self.seq`` token ids.

        Whole documents, each followed by EOS, are concatenated until at
        least ``self.seq`` tokens are buffered. The window start is then
        drawn uniformly from every offset that still fits a full window.
        """
        obj = []
        while len(obj) < self.seq:
            obj.extend(self._get_next())
            obj.append(self.tok.eos_token_id)
        s = random.randint(0, len(obj) - self.seq)
        return obj[s:s + self.seq]

    def __getitem__(self, _):
        """Return the next sample as a tensor of token ids; index is ignored."""
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        """Convert token ids back to text with the dataset's tokenizer."""
        return self.tok.decode(tkns)