Spaces:
Paused
Paused
File size: 3,118 Bytes
5bb6ad4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from datasets import load_dataset
from torch.utils.data import DataLoader
from tokenizer import Tokenizer
from config import ModelArgs
# Module-wide tokenizer, as returned by the project's Tokenizer.ready_tokenizer().
tokenizer = Tokenizer().ready_tokenizer()

# Corpus switches: TinyStories vs. the FineWeb "sample-10BT" subset.
tinystories = True
fw = False

# Dataset handles, filled in below according to the switches above.
fw_train = fw_test = None

if tinystories:
    # TinyStories ships separate train / validation splits.
    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train, fw_test, sep="\n")

if fw:
    # Only a train split is loaded here, so 1% of it is carved off as a test
    # split. After this, fw_train is the object returned by train_test_split
    # and is indexed with 'train' / 'test' by prepare_dataset.
    fw_train = load_dataset(
        "HuggingFaceFW/fineweb",
        name="sample-10BT",
        split="train",
        streaming=False,
    ).train_test_split(test_size=0.01)
    # The original prints the same object twice; kept to preserve output.
    print(fw_train, fw_train, sep="\n")

# Register '[PAD]' as the pad token; the tokenizer calls in this file all
# request padding='max_length'.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
def tokenize_function(examples):
    """Tokenize ``examples['text']`` with the module-level tokenizer.

    The text is truncated and padded to ``ModelArgs.block_size`` and the
    tokenizer is asked for PyTorch tensors (``return_tensors='pt'``).

    Args:
        examples: Mapping-like object with a ``'text'`` entry.

    Returns:
        Whatever the tokenizer call returns for the given text.
    """
    texts = examples['text']
    encode_options = {
        'max_length': ModelArgs.block_size,
        'padding': 'max_length',
        'truncation': True,
        'return_tensors': 'pt',
    }
    return tokenizer(texts, **encode_options)
def prepare_dataset(split, device, batch_size):
    """Build the ``DataLoader`` for one split of the currently enabled corpus.

    The corpus is chosen by the module-level ``tinystories`` / ``fw`` flags
    (``tinystories`` takes precedence when both are set).

    Args:
        split: ``'train'`` or ``'val'``.
        device: Only printed for diagnostics; batches are not moved to it here.
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.

    Returns:
        A ``DataLoader`` over the selected split that draws indices from a
        shuffling ``DistributedSampler``, drops the last incomplete batch, and
        collates raw ``"text"`` records into tokenized tensors plus a
        ``"labels"`` entry holding the next-token targets.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
        RuntimeError: If neither ``tinystories`` nor ``fw`` is enabled.
    """
    print("Device is: ", device)

    # Fail early with a clear message. Previously an unknown split (or no
    # enabled corpus) fell through every branch and crashed at the return
    # with UnboundLocalError.
    if split not in ('train', 'val'):
        raise ValueError(
            f"Unknown split {split!r}; expected 'train' or 'val'"
        )

    def collate_fn(batch):
        """Tokenize a list of ``{"text": ...}`` records and attach labels."""
        texts = [item["text"] for item in batch]
        input_encodings = tokenizer(
            texts,
            max_length=ModelArgs.block_size,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        input_ids = input_encodings["input_ids"]
        labels = input_ids.clone()
        # Next-token targets: position i is paired with token i + 1.
        labels[:, :-1] = input_ids[:, 1:]
        # The last position has no successor, so it is paired with EOS.
        labels[:, -1] = tokenizer.eos_token_id
        input_encodings["labels"] = labels
        return input_encodings

    is_train = split == 'train'
    if tinystories:
        dataset = fw_train if is_train else fw_test
    elif fw:
        # In fw mode fw_train holds the result of train_test_split.
        dataset = fw_train['train' if is_train else 'test']
    else:
        raise RuntimeError(
            "No dataset is enabled: set either `tinystories` or `fw` to True"
        )

    # All four original branches built the same loader; only the dataset
    # differed. DistributedSampler shuffles by default, so the fw validation
    # branch (which omitted shuffle=True) behaved identically.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=DistributedSampler(dataset, shuffle=True),
        collate_fn=collate_fn,
        drop_last=True,
        # Shuffling is delegated to the sampler, so the loader's own stays off.
        shuffle=False,
    )
|