# gss/pile_hf.py
import torch
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase
from pile import ThePile
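# `pile.ThePile` is the repo's document-level dataset; each item is expected
# to be a dict with a "text" field (see `__iter__` below).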


class ThePileTokenized(IterableDataset):
    """Tokenize The Pile on the fly and yield fixed-length token windows.

    Documents are joined with EOS separators and packed into a running
    buffer that is sliced into `max_length` windows; a `repeat_factor` > 1
    makes consecutive windows overlap so each token is emitted roughly
    `repeat_factor` times.
    """
def __init__(
self,
base_dataset: ThePile,
tokenizer: PreTrainedTokenizerBase,
max_length: int = 1024,
repeat_factor: float = 1.0,
):
self.pile = base_dataset
self.tokenizer = tokenizer
self.max_length = max_length
self.repeat_factor = repeat_factor

    def __iter__(self):
        ds = iter(self.pile)
        buffer = []
        while True:
            # Stop cleanly when the pile is exhausted; letting next() raise
            # here would surface as a RuntimeError under PEP 479.
            try:
                tokens = self.tokenizer.encode(next(ds)["text"])
            except StopIteration:
                return
            # Join documents with an EOS separator so windows can span them.
            buffer += [self.tokenizer.eos_token_id] + tokens
            while len(buffer) > self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                # Stride of max_length / repeat_factor: consecutive windows
                # overlap, so each token is seen ~repeat_factor times.
                buffer = buffer[int(self.max_length / self.repeat_factor) :]
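                # Example with the settings in __main__ below: max_length=2048
                # and repeat_factor=4/3 give a stride of int(2048 / (4/3)) =
                # 1536, i.e. 512 tokens shared between consecutive windows.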


if __name__ == "__main__":
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer
dataset = ThePileTokenized(
ThePile("train"),
GPT2Tokenizer.from_pretrained("gpt2"),
max_length=2048,
repeat_factor=4 / 3,
)
dataloader = DataLoader(
dataset,
batch_size=1,
)
    # Throughput check only; batches are discarded.
    for batch in tqdm(dataloader, smoothing=0.01):
        pass
    # ~6 iters/s for 1 worker
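
# Note: each DataLoader worker iterates its own copy of an IterableDataset,
# so num_workers > 1 with the class above would yield duplicate batches.
# A minimal sketch of round-robin document sharding (an addition, not part
# of the original file), using torch.utils.data.get_worker_info():
import itertools
from torch.utils.data import get_worker_info

def shard_for_worker(docs):
    """Yield only the documents assigned to the current DataLoader worker."""
    worker = get_worker_info()  # None in the main process (num_workers == 0)
    if worker is None:
        yield from docs
    else:
        # Worker i keeps documents i, i + num_workers, i + 2 * num_workers, ...
        yield from itertools.islice(docs, worker.id, None, worker.num_workers)
# ThePileTokenized.__iter__ could then wrap its source as
# `ds = shard_for_worker(iter(self.pile))` to make workers disjoint.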