Spaces:
Sleeping
Sleeping
File size: 1,104 Bytes
35c1cfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from tqdm import tqdm
from itertools import chain
from torch.utils.data import Dataset
class ConcatDataset(Dataset):
    """Pack a tokenized dataset into fixed-length chunks.

    Samples are concatenated end-to-end, field by field, into one running
    token stream. The stream is cut into consecutive pieces of exactly
    ``chunk_size`` tokens. Each resulting item is a dict with the keys
    ``"input_ids"``, ``"attention_mask"`` and ``"labels"``. Tokens left over
    after the last full chunk (fewer than ``chunk_size``) are dropped. Any
    other keys present in a sample are ignored.

    Args:
        dataset: Iterable of samples. Each sample must provide the three keys
            above. The values are presumably Python lists of token ids of
            equal length per sample — TODO confirm against the tokenizer
            output feeding this class.
        chunk_size: Number of tokens per packed item; must be positive.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """

    # Fields that are concatenated and chunked; other sample keys are ignored.
    _FIELDS = ("input_ids", "attention_mask", "labels")

    def __init__(self, dataset, chunk_size=4096):
        # A non-positive chunk size would make the packing loop spin forever:
        # slicing off zero (or a negative count of) tokens never drains the buffer.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        self.dataset = dataset
        self.chunk_size = chunk_size
        progress = tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True)
        self.samples = self._pack(progress, chunk_size, self._FIELDS)

    @staticmethod
    def _pack(samples, chunk_size, fields):
        """Concatenate ``samples`` per field and return every full ``chunk_size`` chunk.

        Args:
            samples: Iterable of dict-like samples providing every key in ``fields``.
            chunk_size: Positive number of tokens per emitted chunk.
            fields: Non-empty sequence of keys to pack. The length of
                ``fields[0]`` drives the chunking for all fields.

        Returns:
            A list of dicts, one per full chunk, each mapping every field to a
            new list of ``chunk_size`` items. A trailing partial chunk is dropped.
        """
        buffer = {k: [] for k in fields}
        packed = []
        for sample in samples:
            for k in fields:
                buffer[k].extend(sample[k])
            # The first field's length stands in for all fields.
            available = len(buffer[fields[0]])
            start = 0
            # ">=" so a buffer holding exactly one chunk is emitted right away
            # instead of being lost if this was the final sample.
            while available - start >= chunk_size:
                end = start + chunk_size
                packed.append({k: v[start:end] for k, v in buffer.items()})
                start = end
            if start:
                # Drop consumed tokens once per sample rather than once per chunk.
                buffer = {k: v[start:] for k, v in buffer.items()}
        return packed

    def __getitem__(self, idx):
        """Return the ``idx``-th packed chunk."""
        return self.samples[idx]

    def __len__(self):
        """Return the number of packed chunks."""
        return len(self.samples)
|