import torch
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from pile import ThePile


class ThePileTokenized(IterableDataset):
    """Streams The Pile as fixed-length sequences of token ids.

    Documents are tokenized, separated by EOS tokens, and packed into a
    rolling buffer; a repeat_factor greater than 1 makes consecutive windows
    overlap, so each token is yielded roughly repeat_factor times.
    """

    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
        repeat_factor: float = 1.0,
    ):
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repeat_factor = repeat_factor
    def __iter__(self):
        ds = iter(self.pile)
        buffer = []
        while True:
            # Append the next document to the buffer, prefixed with an EOS
            # token as a document separator.
            tokens = self.tokenizer.encode(next(ds)["text"])
            buffer += [self.tokenizer.eos_token_id] + tokens
            # Emit fixed-length windows. Advancing the buffer by
            # max_length / repeat_factor (instead of max_length) makes
            # consecutive windows overlap, so each token is seen roughly
            # repeat_factor times.
            while len(buffer) > self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                buffer = buffer[int(self.max_length / self.repeat_factor) :]


if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=2048,
        repeat_factor=4 / 3,
    )
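    # With max_length=2048 and repeat_factor=4/3, each yielded window advances
    # the buffer by int(2048 / (4 / 3)) = 1536 tokens, so consecutive windows
    # overlap by 512 tokens and each token is seen ~4/3 times on average.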
    dataloader = DataLoader(
        dataset,
        batch_size=1,
    )
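    # Note: this keeps the DataLoader single-process (num_workers=0). With an
    # IterableDataset, each extra worker would replay the same stream, so
    # per-worker sharding would be needed before raising num_workers.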
    # Drive the iterator to exercise tokenization and packing; tqdm reports
    # the resulting sequences-per-second throughput.
    for batch in tqdm(dataloader, smoothing=0.01):
        pass