# StoryLlama / trainer.py
# YuvrajSingh9886's picture
# Upload 12 files
# 5bb6ad4 verified
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch
import wandb
import torch.optim as optim
import os
from config import ModelArgs
from model import Llama
from inference import greedy_decode
from data import prepare_dataset
from tokenizer import Tokenizer
# Module-level training hyper-parameters.
# NOTE(review): these are derived from the ModelArgs class attributes at import
# time; values later supplied on the command line do not change them.
torch.set_float32_matmul_precision('high')

# Gradient scaling is only switched on when the configured dtype is float16.
scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))

save_chechpoint_iter = 50  # historical (misspelled) name, still read by train()
save_checkpoint_iter = save_chechpoint_iter  # correctly spelled alias
total_iters = 10000
eval_iters = 50
eval_check = 100
warmup_iters = 700
min_lr = 0.1 * ModelArgs.max_lr
lr_decay_iters = 10000
total_batch_size = 524288
micro_batch_size = ModelArgs.batch_size

# torch.cuda.device_count() is 0 on a CPU-only host; treat that as a single
# device so importing this module does not raise ZeroDivisionError.
_device_count = max(1, torch.cuda.device_count())
_tokens_per_micro_step = micro_batch_size * ModelArgs.block_size * _device_count
# Never fall below one accumulation step, otherwise the training loop would
# run no forward/backward pass at all.
gradient_accumulation_steps = max(1, total_batch_size // _tokens_per_micro_step)
class Trainer:
    """Bundle the run configuration with process-group and checkpoint helpers."""

    def __init__(self, model_args):
        """Store the configuration, build the tokenizer and join the process group.

        Args:
            model_args: Configuration object, kept as ``self.model_args``.
        """
        self.model_args = model_args
        self.tokenizer = Tokenizer().ready_tokenizer()
        # Joins the default process group using the NCCL backend.
        init_process_group("nccl")

    def cleanup(self):
        """Tear down the distributed process group."""
        destroy_process_group()

    def _save_snapshot(self, model, optimizer, epoch, step, save_dir):
        """Write model/optimizer state and progress counters to ``snapshot.pt``.

        Args:
            model: Model to checkpoint. When it exposes a ``module`` attribute
                (as a DDP wrapper does) the wrapped module's weights are saved;
                otherwise the model's own ``state_dict`` is used.
            optimizer: Optimizer whose ``state_dict`` is stored.
            epoch: Value stored under ``"EPOCHS_RUN"``.
            step: Value stored under ``"STEP_RUN"``.
            save_dir: Directory receiving ``snapshot.pt``; created if missing.
        """
        # torch.save does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)
        unwrapped = getattr(model, "module", model)
        snapshot = {
            "MODEL_STATE": unwrapped.state_dict(),
            "OPTIMIZER_STATE": optimizer.state_dict(),
            "EPOCHS_RUN": epoch,
            "STEP_RUN": step,
        }
        torch.save(snapshot, os.path.join(save_dir, "snapshot.pt"))
        print(f"Epoch: {epoch} | step {step} | Training snapshot saved at snapshot.pt")
def warmup_fn(step, warmup_steps=2000):
    """Return the linear warm-up multiplier for ``step``.

    Args:
        step: Current step index.
        warmup_steps: Length of the warm-up phase; defaults to the 2000 steps
            that were previously hard-coded.

    Returns:
        ``step / warmup_steps`` while ``step < warmup_steps``, otherwise ``1.0``.
    """
    if warmup_steps <= 0:
        # No warm-up phase configured: always use the full learning rate.
        return 1.0
    if step < warmup_steps:
        return step / warmup_steps  # LR gradually increases
    return 1.0
# learning rate decay scheduler (cosine with warmup) from https://github.com/karpathy/nanoGPT/blob/master/train.py
def get_lr(it):
    """Return the learning rate for iteration ``it``.

    The schedule has three phases, driven by the module-level ``warmup_iters``,
    ``lr_decay_iters``, ``min_lr`` and ``ModelArgs.max_lr``:

    1. linear warm-up while ``it < warmup_iters``;
    2. cosine decay from ``ModelArgs.max_lr`` towards ``min_lr``;
    3. constant ``min_lr`` once ``it >= lr_decay_iters``.

    Args:
        it: Zero-based iteration index.

    Returns:
        The learning rate as a float.
    """
    # Imported here because the file's import block does not provide `math`.
    import math

    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
    # 2) at or past lr_decay_iters, return min learning rate. The cosine term
    #    is exactly 0 at it == lr_decay_iters, so `>=` gives the same value
    #    while avoiding a 0/0 when lr_decay_iters == warmup_iters.
    if it >= lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate; the two
    #    guards above keep decay_ratio within [0, 1).
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (ModelArgs.max_lr - min_lr)
def train():
    """Run the DDP pre-training loop on the current process.

    The function joins the NCCL process group, binds the process to the GPU
    given by the ``LOCAL_RANK`` environment variable, wraps the ``Llama`` model
    in ``DistributedDataParallel`` and then, for ``total_iters`` optimizer
    steps:

    * every ``eval_iters`` steps (and on the last step) averages the validation
      loss over ``eval_check`` batches and reduces it onto rank 0;
    * every ``save_chechpoint_iter`` steps writes a snapshot from rank 0;
    * accumulates gradients over ``gradient_accumulation_steps`` micro-batches,
      applies the ``get_lr`` schedule, optionally clips gradients and steps the
      optimizer through the module-level ``scaler``;
    * logs metrics to Weights & Biases from rank 0.

    NOTE(review): ``eval_iters`` is used as the evaluation *interval* and
    ``eval_check`` as the *number of batches*, the reverse of what the help
    texts in ``parse_args`` describe — confirm which is intended.
    NOTE(review): "rank 0" here means ``LOCAL_RANK == 0``; on a multi-node job
    that is one process per node — confirm single-node use.
    """
    # The process group is managed directly: the only `setup`/`cleanup`
    # definitions in this file are a closure and a method on Trainer.
    init_process_group("nccl")
    device = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(device)
    print(f"Start running DDP on rank {device}.")
    is_main = device == 0

    if is_main:
        wandb.init(project='Llama-DDP-Pretrain-10-billion-tokens')
        print("wand initialized")

    # Built the same way Trainer.__init__ builds it; needed for pad_token_id.
    tokenizer = Tokenizer().ready_tokenizer()

    # NOTE(review): `topk_sampling` is neither defined nor imported in this
    # file. It presumably lives in inference.py next to `greedy_decode` —
    # confirm. If it cannot be imported, sample generation is skipped instead
    # of crashing the run with a NameError.
    try:
        from inference import topk_sampling
    except ImportError:
        topk_sampling = None
        if is_main:
            print("topk_sampling could not be imported; skipping sample generation")

    model = Llama(
        embeddings_dims=ModelArgs.embeddings_dims,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
        device=device,
    )
    print(f"Model on device {device} is ready")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=ModelArgs.max_lr,
        betas=(ModelArgs.beta_1, ModelArgs.beta_2),
        weight_decay=ModelArgs.weight_decay_optim,
        eps=ModelArgs.eps,
    )
    model = model.to(device)
    model = DDP(model, device_ids=[device])

    world_size = torch.cuda.device_count()

    def next_batch(loader, iterator):
        """Return ``(batch, iterator)``, restarting ``loader`` once exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(loader)
            return next(iterator), iterator

    def token_loss(idx, targets, autocast_device):
        """Forward ``idx`` under bfloat16 autocast and return the token-level loss."""
        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            logits = model(idx)
            batch_size, block_size, num_classes = logits.shape
            return F.cross_entropy(
                logits.view(batch_size * block_size, num_classes),
                targets.view(batch_size * block_size),
                ignore_index=tokenizer.pad_token_id,
            )

    @torch.inference_mode()
    def estimate_loss(val_loader, val_iterator, device):
        """Average the validation loss over ``eval_check`` batches.

        Returns:
            ``({'val': mean_loss}, val_iterator)``. The iterator is handed back
            so that a restart after exhaustion is kept by the caller.
        """
        model.eval()  # evaluation mode while measuring validation loss
        print("Starting with val evaluation...")
        batch_losses = []
        for _ in range(eval_check):
            batch, val_iterator = next_batch(val_loader, val_iterator)
            idx = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            batch_losses.append(token_loss(idx, targets, device).item())
        model.train()
        mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else 0.0
        return {'val': mean_loss}, val_iterator

    model.train()
    train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
    val_loader = prepare_dataset('val', device, ModelArgs.batch_size)
    print("Loaders ready both")

    train_data_iterator = iter(train_dataloader)
    val_data_iterator = iter(val_loader)
    token_count = 0

    # A bare range replaces the original `tqdm(...)` wrapper: tqdm is not
    # imported by this file, and rank 0 already prints the step counter.
    for step in range(total_iters):
        if is_main:
            print("Step : ", step, "/", total_iters)
            print('Total batches: ', len(train_dataloader))
            print("Total gradient accumulation steps: ", gradient_accumulation_steps)
            print("Total tokens processed: ", token_count)

        # ---- periodic validation (runs on every rank so the reduce matches) ----
        if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses, val_data_iterator = estimate_loss(val_loader, val_data_iterator, 'cuda')
            print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")
            avg_val_loss = torch.tensor([losses['val']], device=device)
            torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
            if is_main:
                all_gpus_avg_val_loss = avg_val_loss / world_size
                print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")
                wandb.log({
                    "All_GPUs_Val_losses": all_gpus_avg_val_loss.item(),
                    "val_step_loss": losses['val'],
                })

        # ---- periodic checkpoint from rank 0 ----
        if step % save_chechpoint_iter == 0 and is_main and step != 0:
            print(f"Saving the model checkpoint for step: {step}")
            os.makedirs(ModelArgs.save_checkpoint_dir, exist_ok=True)
            # _save_snapshot is a Trainer method that never reads `self`, so it
            # is called unbound; the step goes in the `step` slot and the
            # configured directory in `save_dir`.
            Trainer._save_snapshot(
                None, model, optimizer, None, step, ModelArgs.save_checkpoint_dir
            )

        # ---- gradient accumulation over micro-batches ----
        accumulated_loss = torch.zeros((), device=device)
        optimizer.zero_grad(set_to_none=True)
        for micro_step in range(gradient_accumulation_steps):
            batch, train_data_iterator = next_batch(train_dataloader, train_data_iterator)
            idx = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            token_count += idx.numel()  # count elements (tokens), not rows

            # Only all-reduce gradients on the last micro-step. DDP reads this
            # flag during the forward pass, so it has to be set before
            # model(idx) runs.
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)

            loss = token_loss(idx, targets, ModelArgs.device)
            # Each micro-batch contributes 1/N of the large-batch loss so the
            # summed gradients equal those of the mean over the whole batch.
            loss = loss / gradient_accumulation_steps
            accumulated_loss += loss.detach()
            scaler.scale(loss).backward()

            # Stand-in for the `find_unused_parameters` helper that is not
            # defined in this file: report trainable parameters that received
            # no gradient (presumably what the helper did — confirm).
            unused_params = [
                name
                for name, param in model.named_parameters()
                if param.requires_grad and param.grad is None
            ]
            if unused_params:
                print(f"Unused parameters: {unused_params}")

            if is_main and micro_step % 10 == 0:
                print("Micro Batch : ", micro_step)
                print("Step : ", step, "/", total_iters)
                print('Total batches: ', len(train_dataloader))
                print("Total gradient accumulation steps: ", gradient_accumulation_steps)
                print("Total tokens processed: ", token_count)

        # ---- learning-rate schedule ----
        lr = get_lr(step)
        for group in optimizer.param_groups:
            group['lr'] = lr

        # ---- gradient clipping ----
        if ModelArgs.clip != 0.0:
            scaler.unscale_(optimizer)  # clip true gradients, not scaled ones
            # clip_grad_norm_ returns the total norm measured before clipping.
            total_norm_before = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=ModelArgs.clip
            )
            grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
            if grads:
                total_norm_after = torch.norm(
                    torch.stack([torch.norm(g, 2) for g in grads]), 2
                )
            else:
                total_norm_after = torch.zeros((), device=device)
            if is_main and step != 0:
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

        scaler.step(optimizer)
        scaler.update()

        torch.cuda.synchronize()
        # Sum the accumulated step loss onto rank 0, then average over ranks so
        # the logged value matches its "All_GPUs" label.
        torch.distributed.reduce(accumulated_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        if is_main:
            wandb.log({
                "Learning Rate": lr,
                "All_GPUs_Train_losses": (accumulated_loss / world_size).item(),
                "Step": step,
            })

        # ---- sample generation on the main process only ----
        if is_main and step % 5 == 0 and topk_sampling is not None:
            prompt = "Once upon a time"
            for _ in range(3):
                generated_text = topk_sampling(
                    model, prompt, max_length=50, top_k=50, temperature=1.0, device=device
                )
                print(f" Step: {step} | Generated Text: {generated_text}")

    if is_main:
        wandb.finish()
    destroy_process_group()
# Module-level report of how many CUDA devices torch can see.
world_size = torch.cuda.device_count()
print(f"World size: {world_size}")
def parse_args():
    """Parse the command line into a namespace of training settings.

    Every option is ``--<name>`` and defaults to the ``ModelArgs`` class
    attribute of the same name.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    # (option name, value type, help text) — registered in this order.
    # The former "--inference" option stays disabled.
    option_table = (
        # One entry for each field in ModelArgs
        ("epochs", int, "Number of training epochs."),
        ("block_size", int, "Block size for the model."),
        ("batch_size", int, "Batch size for training."),
        ("embeddings_dims", int, "Embedding dimensions."),
        ("attn_dropout", float, "Attention dropout rate."),
        ("no_of_heads", int, "Number of attention heads."),
        ("dropout", float, "Dropout rate."),
        ("val_epochs", int, "Number of validation epochs."),
        ("max_lr", float, "Learning rate."),
        ("no_of_decoder_layers", int, "Number of decoder layers."),
        ("weight_decay_optim", float, "Weight decay for optimizer."),
        ("beta_1", float, "Beta1 for Adam optimizer."),
        ("beta_2", float, "Beta2 for Adam optimizer."),
        ("clip", float, "Gradient clipping value."),
        ("device", str, "Device to run the model on (e.g., 'cuda' or 'cpu')."),
        ("no_kv_heads", int, "Number of key/value heads."),
        ("vocab_size", int, "Vocabulary size."),
        ("eps", float, "Epsilon value for numerical stability."),
        ("dtype", str, "Data type for tensors (e.g., 'float16' or 'bfloat16')."),
        ("save_checkpoint_dir", str, "Directory to save model checkpoints."),
        ("prompt", str, "Prompt for testing during training."),
        # Additional arguments
        ("save_checkpoint_iter", int, "Save checkpoint every N iterations."),
        ("total_iters", int, "Total number of training iterations."),
        ("eval_iters", int, "Number of iterations for evaluation."),
        ("eval_check", int, "Evaluate model every N iterations."),
        ("warmup_iters", int, "Number of warmup iterations for learning rate scheduling."),
        ("min_lr", float, "Minimum learning rate."),
        ("lr_decay_iters", int, "Number of iterations for learning rate decay."),
        ("total_batch_size", int, "Total batch size across all devices."),
        ("micro_batch_size", int, "Micro batch size per device."),
        ("gradient_accumulation_steps", int, "Number of gradient accumulation steps."),
    )

    parser = argparse.ArgumentParser(description="Model Training Arguments")
    for option, value_type, help_text in option_table:
        parser.add_argument(
            f"--{option}",
            type=value_type,
            default=getattr(ModelArgs, option),
            help=help_text,
        )
    return parser.parse_args()
def initialize_model_args(args):
    """Build a ``ModelArgs`` instance from parsed command-line arguments.

    Args:
        args: Namespace exposing one attribute per field listed below.

    Returns:
        ``ModelArgs`` constructed with each field passed by keyword.
    """
    # Fields copied verbatim from the namespace; "inference" stays excluded.
    field_names = (
        "epochs",
        "block_size",
        "batch_size",
        "embeddings_dims",
        "attn_dropout",
        "no_of_heads",
        "dropout",
        "val_epochs",
        "max_lr",
        "no_of_decoder_layers",
        "weight_decay_optim",
        "beta_1",
        "beta_2",
        "clip",
        "device",
        "no_kv_heads",
        "vocab_size",
        "eps",
        "dtype",
        "save_checkpoint_dir",
        "prompt",
        "save_checkpoint_iter",
        "total_iters",
        "eval_iters",
        "eval_check",
        "warmup_iters",
        "min_lr",
        "lr_decay_iters",
        "total_batch_size",
        "micro_batch_size",
        "gradient_accumulation_steps",
    )
    overrides = {field: getattr(args, field) for field in field_names}
    return ModelArgs(**overrides)
if __name__ == "__main__":
    # Script entry point: read the CLI options, then turn them into a
    # ModelArgs instance.
    # NOTE(review): nothing in the visible part of the file calls train() or
    # uses `model_args` after this point — confirm whether that is intended.
    args = parse_args()
    model_args = initialize_model_args(args)