import logging

import torch

from src.utils.helpers import get_batch


@torch.no_grad()
def estimate_loss(model, eval_iters, block_size, batch_size, device):
    """Estimate the mean loss over eval_iters random batches for each split."""
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split, block_size, batch_size)
            xb, yb = xb.to(device), yb.to(device)
            logits, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out


def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
    checkpoint_path="checkpoints/model.pth",
):
    """Run the training loop, periodically evaluating and checkpointing the best model."""
    logger = logging.getLogger(__name__)
    best_val_loss = float('inf')

    for step in range(max_iters):
        # Evaluation
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            logger.info(
                f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )

            # Save best model
            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                logger.info(f"Saving model with val loss: {best_val_loss:.4f}")
                torch.save(model, checkpoint_path)

        # Training step
        xb, yb = get_batch('train', block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)

        # Forward pass
        logits, loss = model(xb, yb)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # Save final model
    torch.save(model, checkpoint_path)
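

# --- Illustrative usage sketch (not part of the training API) ---------------
# The block below shows how train() might be wired to a model and optimizer.
# Everything in it is an assumption for demonstration: the toy bigram model,
# the AdamW optimizer, and the hyperparameter values are placeholders. The
# only real requirement, taken from the loop above, is that the model returns
# (logits, loss) when called as model(xb, yb).
if __name__ == "__main__":
    import os

    import torch.nn as nn
    import torch.nn.functional as F

    class _ToyBigramModel(nn.Module):
        """Minimal stand-in model that matches the (logits, loss) call contract."""

        def __init__(self, vocab_size):
            super().__init__()
            # Each token directly predicts logits for the next token.
            self.token_embedding = nn.Embedding(vocab_size, vocab_size)

        def forward(self, idx, targets=None):
            logits = self.token_embedding(idx)  # (B, T, vocab_size)
            loss = None
            if targets is not None:
                B, T, C = logits.shape
                loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
            return logits, loss

    logging.basicConfig(level=logging.INFO)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Placeholder vocabulary size; in practice this comes from the tokenizer/dataset.
    model = _ToyBigramModel(vocab_size=65).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # Ensure the checkpoint directory exists before train() calls torch.save().
    os.makedirs("checkpoints", exist_ok=True)

    train(
        model,
        optimizer,
        max_iters=5000,
        eval_interval=500,
        eval_iters=200,
        block_size=256,
        batch_size=64,
        device=device,
        checkpoint_path="checkpoints/model.pth",
    )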