# Captured from a Hugging Face Spaces page (Space status banner: "Paused").
import argparse
import math
import os

import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from config import ModelArgs
from data import prepare_dataset
from inference import greedy_decode
from model import Llama
from tokenizer import Tokenizer
# Let float32 matmuls use faster reduced-precision kernels (e.g. TF32) where available.
torch.set_float32_matmul_precision('high')

# Loss scaling is only needed for float16; for any other dtype the scaler is
# constructed disabled and its scale/unscale_/step/update calls are pass-throughs.
scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))

save_chechpoint_iter = 50  # snapshot every N optimizer steps (sic: misspelled name kept, other code refers to it)
total_iters = 10000        # number of optimizer steps
eval_iters = 50            # run validation every N optimizer steps
eval_check = 100           # validation batches averaged per evaluation
warmup_iters = 700         # length of the linear LR warmup, in optimizer steps
min_lr = 0.1 * ModelArgs.max_lr  # floor of the cosine LR schedule
lr_decay_iters = 10000     # step at which the cosine decay reaches min_lr
total_batch_size = 524288  # tokens per optimizer step summed over all GPUs (2**19)
micro_batch_size = ModelArgs.batch_size  # sequences per forward pass on each GPU

# Clamp to 1 so this module can be imported on a machine with no visible CUDA
# device; device_count() == 0 would otherwise divide by zero below.
_visible_gpus = max(1, torch.cuda.device_count())
# Micro-batches accumulated per optimizer step on each GPU. Kept >= 1: a value
# of 0 would skip the accumulation loop and leave no gradients to step on.
gradient_accumulation_steps = max(
    1, total_batch_size // (micro_batch_size * ModelArgs.block_size * _visible_gpus)
)
class Trainer:
    """Per-process DDP context: process group, tokenizer and checkpoint helpers."""

    def __init__(self, model_args):
        """Store ``model_args``, load the tokenizer and join the NCCL process group.

        The default process group is only created when this process does not
        already have one, so constructing a second Trainer does not raise.

        Args:
            model_args: Training configuration; stored as ``self.model_args``.
        """

        def setup(rank=None, world_size=None):
            # rank/world_size are unused: init_process_group reads them from
            # the launcher's environment (default env:// initialisation).
            already_up = torch.distributed.is_available() and torch.distributed.is_initialized()
            if not already_up:
                init_process_group("nccl")

        self.model_args = model_args
        self.tokenizer = Tokenizer().ready_tokenizer()
        setup()

    def cleanup(self):
        """Destroy the default process group if one is active (safe to call twice)."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            destroy_process_group()

    def _save_snapshot(self, model, optimizer, epoch, step, save_dir):
        """Write model/optimizer state and progress counters to ``save_dir/snapshot.pt``.

        Args:
            model: A DDP-wrapped model (its ``.module`` is saved, so the file
                loads into an unwrapped model) or a plain ``nn.Module``.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            epoch: Stored under ``"EPOCHS_RUN"`` (may be None).
            step: Stored under ``"STEP_RUN"`` (may be None).
            save_dir: Target directory; created if missing. Each call
                overwrites the previous ``snapshot.pt`` in it.
        """
        os.makedirs(save_dir, exist_ok=True)
        # Unwrap DDP when present; a plain module has no `.module` attribute.
        inner_model = getattr(model, "module", model)
        snapshot = {
            "MODEL_STATE": inner_model.state_dict(),
            "OPTIMIZER_STATE": optimizer.state_dict(),
            "EPOCHS_RUN": epoch,
            "STEP_RUN": step,
        }
        torch.save(snapshot, os.path.join(save_dir, "snapshot.pt"))
        print(f"Epoch: {epoch} | step {step} | Training snapshot saved at snapshot.pt")
# Linear warmup multiplier over the first 2000 steps.
def warmup_fn(step):
    """Return the LR multiplier for a 2000-step linear warmup.

    Rises from 0.0 at step 0 as ``step / 2000`` and is 1.0 from step 2000 on.
    """
    warmup_steps = 2000
    return step / warmup_steps if step < warmup_steps else 1.0
# learning rate decay scheduler (cosine with warmup) from https://github.com/karpathy/nanoGPT/blob/master/train.py
def get_lr(it, *, warmup_steps=None, decay_steps=None, peak_lr=None, floor_lr=None):
    """Return the learning rate for optimizer step ``it``.

    Linear warmup to ``peak_lr`` over ``warmup_steps`` steps, cosine decay
    from ``peak_lr`` to ``floor_lr`` until ``decay_steps``, then ``floor_lr``.

    Args:
        it: Zero-based optimizer step.
        warmup_steps: Warmup length; defaults to module-level ``warmup_iters``.
        decay_steps: Step where decay ends; defaults to ``lr_decay_iters``.
        peak_lr: Maximum rate; defaults to ``ModelArgs.max_lr``.
        floor_lr: Final rate; defaults to module-level ``min_lr``.

    Returns:
        The learning rate as a float.
    """
    if warmup_steps is None:
        warmup_steps = warmup_iters
    if decay_steps is None:
        decay_steps = lr_decay_iters
    if peak_lr is None:
        peak_lr = ModelArgs.max_lr
    if floor_lr is None:
        floor_lr = min_lr

    # 1) linear warmup; the +1 keeps the very first step's rate non-zero
    if it < warmup_steps:
        return peak_lr * (it + 1) / (warmup_steps + 1)
    # 2) past the decay horizon, or no decay window at all (avoids 0/0): hold the floor
    if it > decay_steps or decay_steps <= warmup_steps:
        return floor_lr
    # 3) in between, cosine from peak_lr (ratio 0) down to floor_lr (ratio 1)
    decay_ratio = (it - warmup_steps) / (decay_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return floor_lr + coeff * (peak_lr - floor_lr)
def train():
    """Pre-train the Llama model with DistributedDataParallel.

    Run once per GPU under a launcher that exports ``LOCAL_RANK`` (and usually
    ``RANK``), e.g. torchrun. Each optimizer step accumulates gradients over
    ``gradient_accumulation_steps`` micro-batches. Validation runs every
    ``eval_iters`` steps and on the final step. A snapshot is written every
    ``save_chechpoint_iter`` steps. Global rank 0 logs to Weights & Biases.
    """
    # NOTE(review): the original called an undefined `topk_sampling`; it is
    # looked up in the project's inference module here, and sample generation
    # is skipped if that module does not provide it.
    try:
        from inference import topk_sampling
    except ImportError:
        topk_sampling = None

    device = int(os.environ["LOCAL_RANK"])
    # The global rank decides who logs and saves; on a single node it equals LOCAL_RANK.
    rank = int(os.environ.get("RANK", device))
    is_master = rank == 0
    torch.cuda.set_device(device)

    # Trainer.__init__ joins the NCCL process group and loads the tokenizer.
    trainer = Trainer(ModelArgs)
    pad_token_id = trainer.tokenizer.pad_token_id
    # Processes in the group; this can be fewer than the visible GPUs.
    world_size = torch.distributed.get_world_size()
    print(f"Start running DDP on rank {device}.")

    if is_master:
        wandb.init(project='Llama-DDP-Pretrain-10-billion-tokens')
        print("wand initialized")

    model = Llama(
        embeddings_dims=ModelArgs.embeddings_dims,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
        device=device,
    )
    print(f"Model on device {device} is ready")
    model = DDP(model.to(device), device_ids=[device])
    optimizer = optim.AdamW(
        model.parameters(),
        lr=ModelArgs.max_lr,
        betas=(ModelArgs.beta_1, ModelArgs.beta_2),
        weight_decay=ModelArgs.weight_decay_optim,
        eps=ModelArgs.eps,
    )

    def next_batch(loader, iterator):
        """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(loader)
            return next(iterator), iterator

    def compute_loss(idx, targets):
        """Token-level cross-entropy of ``model(idx)`` against ``targets``, ignoring padding."""
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            logits = model(idx)
            batch_size, block_size, vocab_size = logits.shape
            return F.cross_entropy(
                logits.view(batch_size * block_size, vocab_size),
                targets.view(batch_size * block_size),
                ignore_index=pad_token_id,
            )

    def estimate_loss(val_loader, val_iterator, device):
        """Mean validation loss over ``eval_check`` batches on this rank.

        Returns ``({'val': loss}, val_iterator)``. The iterator is handed back
        so the next evaluation continues where this one stopped, instead of
        re-reading the first batch of a freshly restarted loader.
        """
        print("Starting with val evaluation...")
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for _ in range(eval_check):
                batch, val_iterator = next_batch(val_loader, val_iterator)
                idx = batch['input_ids'].to(device)
                targets = batch['labels'].to(device)
                batch_losses.append(compute_loss(idx, targets).item())
        model.train()
        mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else 0.0
        return {'val': mean_loss}, val_iterator

    train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
    val_loader = prepare_dataset('val', device, ModelArgs.batch_size)
    print("Loaders ready both")
    train_data_iterator = iter(train_dataloader)
    val_data_iterator = iter(val_loader)
    token_count = 0  # tokens fed through this rank so far
    model.train()

    for step in range(total_iters):
        if is_master:
            print("Step : ", step, "/", total_iters)
            print('Total batches: ', len(train_dataloader))
            print("Total gradient accumulation steps: ", gradient_accumulation_steps)
            print("Total tokens processed: ", token_count)

        # Every rank evaluates at the same step, so all peers join the reduce.
        if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses, val_data_iterator = estimate_loss(val_loader, val_data_iterator, device)
            print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")
            val_loss_sum = torch.tensor([losses['val']], device=device)
            torch.distributed.reduce(val_loss_sum, dst=0, op=torch.distributed.ReduceOp.SUM)
            if is_master:
                all_gpus_avg_val_loss = (val_loss_sum / world_size).item()
                print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss:.4f}")
                wandb.log({
                    "All_GPUs_Val_losses": all_gpus_avg_val_loss,
                    "val_step_loss": losses['val'],
                })

        if step % save_chechpoint_iter == 0 and is_master and step != 0:
            print(f"Saving the model checkpoint for step: {step}")
            trainer._save_snapshot(model, optimizer, None, step, ModelArgs.save_checkpoint_dir)

        accumulated_loss = torch.zeros(1, device=device)
        optimizer.zero_grad(set_to_none=True)
        for micro_step in range(gradient_accumulation_steps):
            batch, train_data_iterator = next_batch(train_dataloader, train_data_iterator)
            idx = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            token_count += idx.numel()
            # DDP reads this flag inside forward(), so it must be set BEFORE
            # the forward pass: gradients are all-reduced only on the last
            # micro-step and accumulate locally on the earlier ones.
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            # Divide so the summed gradients equal those of one large batch
            # (mean over micro-batches rather than their sum).
            loss = compute_loss(idx, targets) / gradient_accumulation_steps
            accumulated_loss += loss.detach()
            scaler.scale(loss).backward()
            if is_master and micro_step % 10 == 0:
                print("Micro Batch : ", micro_step)
                print("Step : ", step, "/", total_iters)
                print('Total batches: ', len(train_dataloader))
                print("Total gradient accumulation steps: ", gradient_accumulation_steps)
                print("Total tokens processed: ", token_count)

        # One-time diagnostic: trainable parameters that received no gradient.
        if step == 0:
            unused_params = [
                name for name, p in model.named_parameters()
                if p.requires_grad and p.grad is None
            ]
            if unused_params:
                print(f"Unused parameters: {unused_params}")

        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if ModelArgs.clip != 0.0:
            scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
            # clip_grad_norm_ returns the total norm measured before clipping.
            total_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)
            grad_norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
            total_norm_after = torch.stack(grad_norms).norm(2) if grad_norms else torch.zeros(())
            if is_master and step != 0:
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

        scaler.step(optimizer)
        scaler.update()

        # Sum the per-rank step losses on rank 0 so the logged value covers all GPUs.
        torch.distributed.reduce(accumulated_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        if is_master:
            wandb.log({
                "Learning Rate": lr,
                "All_GPUs_Train_losses": (accumulated_loss / world_size).item(),
                "Step": step,
            })

        if is_master and step % 5 == 0 and topk_sampling is not None:
            # Sample from the unwrapped module: a DDP forward on rank 0 alone
            # may enter collectives (e.g. buffer broadcast) the other ranks never join.
            model.eval()
            with torch.no_grad():
                for _ in range(3):
                    generated_text = topk_sampling(
                        model.module, "Once upon a time", max_length=50,
                        top_k=50, temperature=1.0, device=device,
                    )
                    print(f" Step: {step} | Generated Text: {generated_text}")
            model.train()

    if is_master:
        wandb.finish()
    trainer.cleanup()
# Number of CUDA devices visible to this process, reported once at import time.
world_size = torch.cuda.device_count()
print("World size: {}".format(world_size))
def parse_args(argv=None):
    """Parse command-line overrides for the ``ModelArgs`` fields.

    Every option defaults to the ``ModelArgs`` class attribute of the same name.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option.
    """
    parser = argparse.ArgumentParser(description="Model Training Arguments")

    # Model / optimisation fields
    parser.add_argument("--epochs", type=int, default=ModelArgs.epochs, help="Number of training epochs.")
    parser.add_argument("--block_size", type=int, default=ModelArgs.block_size, help="Block size for the model.")
    parser.add_argument("--batch_size", type=int, default=ModelArgs.batch_size, help="Batch size for training.")
    parser.add_argument("--embeddings_dims", type=int, default=ModelArgs.embeddings_dims, help="Embedding dimensions.")
    parser.add_argument("--attn_dropout", type=float, default=ModelArgs.attn_dropout, help="Attention dropout rate.")
    parser.add_argument("--no_of_heads", type=int, default=ModelArgs.no_of_heads, help="Number of attention heads.")
    parser.add_argument("--dropout", type=float, default=ModelArgs.dropout, help="Dropout rate.")
    parser.add_argument("--val_epochs", type=int, default=ModelArgs.val_epochs, help="Number of validation epochs.")
    parser.add_argument("--max_lr", type=float, default=ModelArgs.max_lr, help="Maximum (peak) learning rate.")
    parser.add_argument("--no_of_decoder_layers", type=int, default=ModelArgs.no_of_decoder_layers, help="Number of decoder layers.")
    parser.add_argument("--weight_decay_optim", type=float, default=ModelArgs.weight_decay_optim, help="Weight decay for optimizer.")
    parser.add_argument("--beta_1", type=float, default=ModelArgs.beta_1, help="Beta1 for Adam optimizer.")
    parser.add_argument("--beta_2", type=float, default=ModelArgs.beta_2, help="Beta2 for Adam optimizer.")
    parser.add_argument("--clip", type=float, default=ModelArgs.clip, help="Gradient clipping value.")
    parser.add_argument("--device", type=str, default=ModelArgs.device, help="Device to run the model on (e.g., 'cuda' or 'cpu').")
    parser.add_argument("--no_kv_heads", type=int, default=ModelArgs.no_kv_heads, help="Number of key/value heads.")
    parser.add_argument("--vocab_size", type=int, default=ModelArgs.vocab_size, help="Vocabulary size.")
    parser.add_argument("--eps", type=float, default=ModelArgs.eps, help="Epsilon value for numerical stability.")
    parser.add_argument("--dtype", type=str, default=ModelArgs.dtype, help="Data type for tensors (e.g., 'float16' or 'bfloat16').")
    parser.add_argument("--save_checkpoint_dir", type=str, default=ModelArgs.save_checkpoint_dir, help="Directory to save model checkpoints.")
    parser.add_argument("--prompt", type=str, default=ModelArgs.prompt, help="Prompt for testing during training.")

    # Training-loop fields
    parser.add_argument("--save_checkpoint_iter", type=int, default=ModelArgs.save_checkpoint_iter, help="Save checkpoint every N iterations.")
    parser.add_argument("--total_iters", type=int, default=ModelArgs.total_iters, help="Total number of training iterations.")
    # eval_iters is the evaluation interval and eval_check the number of
    # validation batches per evaluation (the original help texts were swapped).
    parser.add_argument("--eval_iters", type=int, default=ModelArgs.eval_iters, help="Evaluate the model every N iterations.")
    parser.add_argument("--eval_check", type=int, default=ModelArgs.eval_check, help="Number of validation batches per evaluation.")
    parser.add_argument("--warmup_iters", type=int, default=ModelArgs.warmup_iters, help="Number of warmup iterations for learning rate scheduling.")
    parser.add_argument("--min_lr", type=float, default=ModelArgs.min_lr, help="Minimum learning rate.")
    parser.add_argument("--lr_decay_iters", type=int, default=ModelArgs.lr_decay_iters, help="Number of iterations for learning rate decay.")
    parser.add_argument("--total_batch_size", type=int, default=ModelArgs.total_batch_size, help="Total batch size across all devices.")
    parser.add_argument("--micro_batch_size", type=int, default=ModelArgs.micro_batch_size, help="Micro batch size per device.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=ModelArgs.gradient_accumulation_steps, help="Number of gradient accumulation steps.")

    return parser.parse_args(argv)
def initialize_model_args(args):
    """Build a ``ModelArgs`` instance from a parsed argument namespace.

    Exactly the fields named below are copied, in this order, from the
    attributes of ``args``; any other attributes on the namespace are ignored.
    """
    field_names = (
        "epochs",
        "block_size",
        "batch_size",
        "embeddings_dims",
        "attn_dropout",
        "no_of_heads",
        "dropout",
        "val_epochs",
        "max_lr",
        "no_of_decoder_layers",
        "weight_decay_optim",
        "beta_1",
        "beta_2",
        "clip",
        "device",
        "no_kv_heads",
        "vocab_size",
        "eps",
        "dtype",
        "save_checkpoint_dir",
        "prompt",
        "save_checkpoint_iter",
        "total_iters",
        "eval_iters",
        "eval_check",
        "warmup_iters",
        "min_lr",
        "lr_decay_iters",
        "total_batch_size",
        "micro_batch_size",
        "gradient_accumulation_steps",
    )
    overrides = {name: getattr(args, name) for name in field_names}
    return ModelArgs(**overrides)
if __name__ == "__main__":
    # Parse first so bad command-line options fail before any GPU/process-group setup.
    args = parse_args()
    model_args = initialize_model_args(args)
    # NOTE(review): train() reads its hyper-parameters from the ModelArgs class
    # attributes and the module-level constants, not from `model_args`, so
    # command-line overrides are validated here but do not yet take effect.
    train()