import torch
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from pile import ThePile


class ThePileTokenized(IterableDataset):
    """Streams The Pile, tokenizes it, and yields fixed-length token windows."""

    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
        repeat_factor: float = 1.0,
    ):
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repeat_factor = repeat_factor

    def __iter__(self):
        ds = iter(self.pile)
        buffer = []
        while True:
            # Tokenize the next document and append it to the buffer,
            # using EOS as the document separator.
            tokens = self.tokenizer.encode(next(ds)["text"])
            buffer += [self.tokenizer.eos_token_id] + tokens
            # Emit max_length windows. The stride is max_length / repeat_factor,
            # so repeat_factor > 1 produces overlapping windows (e.g. with
            # max_length=2048 and repeat_factor=4/3, the stride is 1536 and
            # consecutive samples share 512 tokens).
            while len(buffer) > self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                buffer = buffer[int(self.max_length / self.repeat_factor) :]


if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=2048,
        repeat_factor=4 / 3,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
    )
    for batch in tqdm(dataloader, smoothing=0.01):
        pass  # consume batches only to measure throughput
    # ~6 iters/s for 1 worker