# StoryLlama / data.py
# Uploaded by YuvrajSingh9886 — commit 5bb6ad4 ("Upload 12 files").
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from datasets import load_dataset
from torch.utils.data import DataLoader
from tokenizer import Tokenizer
from config import ModelArgs
# Shared tokenizer, used by `tokenize_function` and by the DataLoader collate
# function built in `prepare_dataset`.
tokenizer = Tokenizer().ready_tokenizer()

# Dataset selection flags. `prepare_dataset` checks `tinystories` first and only
# falls back to `fw`; the loading code below assumes at most one is True
# (with both set, `fw_train` would be replaced by the FineWeb split dict).
tinystories = True
fw = False
fw_train = None
fw_test = None

if tinystories:
    # TinyStories: separate "train" and "validation" splits are loaded directly.
    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train)
    print(fw_test)

if fw:
    # FineWeb sample-10BT: only the "train" split is loaded, and 1% of it is
    # held out. After `train_test_split`, `fw_train` holds both parts, accessed
    # as fw_train['train'] / fw_train['test'] in `prepare_dataset`; `fw_test`
    # stays None on this path.
    fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
    fw_train = fw_train.train_test_split(test_size=0.01)
    print(fw_train)

# Every tokenizer call in this module pads to max_length, which needs a pad
# token. NOTE(review): presumably the base tokenizer has none — confirm; adding
# one may also grow the vocabulary, so the model's vocab size must account for it.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
def tokenize_function(examples):
    """Encode a batch of examples into fixed-length PyTorch tensors.

    Each sequence is truncated or padded so that it is exactly
    ``ModelArgs.block_size`` tokens long.

    Args:
        examples: Mapping whose ``'text'`` entry holds the raw text to encode.
            Presumably a list of strings, as produced by a batched
            ``Dataset.map`` — TODO confirm against callers.

    Returns:
        The module-level tokenizer's output for these settings, with
        ``'pt'`` (PyTorch) tensors.
    """
    texts = examples['text']
    encode_options = {
        'max_length': ModelArgs.block_size,
        'padding': 'max_length',
        'truncation': True,
        'return_tensors': 'pt',
    }
    return tokenizer(texts, **encode_options)
def prepare_dataset(split, device, batch_size):
    """Build a distributed DataLoader over one split of the selected dataset.

    The dataset is chosen by the module-level flags: ``tinystories`` is checked
    first, then ``fw``.

    Args:
        split: ``'train'`` or ``'val'``.
        device: Only printed for diagnostics; batches are not moved to it here.
        batch_size: Number of examples per batch yielded by this loader.
            Incomplete final batches are dropped.

    Returns:
        A ``DataLoader`` whose batches are tokenizer encodings with an extra
        ``'labels'`` entry holding next-token targets.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'val'``, or if neither
            dataset flag is set.

    NOTE(review): ``DistributedSampler`` is built without ``num_replicas`` or
    ``rank``, so ``torch.distributed`` must already be initialised. Its shuffle
    order only changes between epochs if the caller invokes
    ``loader.sampler.set_epoch(epoch)`` — confirm the training loop does so.
    """
    print("Device is: ", device)

    def collate_fn(batch):
        """Tokenize raw records and attach shifted next-token labels."""
        texts = [item["text"] for item in batch]
        input_encodings = tokenizer(texts, max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
        # labels[:, t] = input_ids[:, t + 1]: each position's target is the next token.
        input_encodings["labels"] = input_encodings["input_ids"].clone()
        input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]
        # The last position has no next token inside the window; use EOS as its target.
        input_encodings["labels"][:, -1] = tokenizer.eos_token_id
        # NOTE(review): padded positions are not masked (e.g. with -100), so they
        # contribute to the loss unless the caller ignores the pad id — confirm.
        return input_encodings

    if tinystories:
        train_ds, val_ds = fw_train, fw_test
    elif fw:
        # FineWeb was split at import time into 'train' and 'test' parts.
        train_ds, val_ds = fw_train['train'], fw_train['test']
    else:
        raise ValueError("No dataset selected: set either `tinystories` or `fw` to True.")

    if split == 'train':
        dataset = train_ds
    elif split == 'val':
        dataset = val_ds
    else:
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'val'.")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=DistributedSampler(dataset, shuffle=True),
        collate_fn=collate_fn,
        drop_last=True,
        # Ordering comes from the sampler; DataLoader rejects shuffle=True with a sampler.
        shuffle=False,
    )