"""Training utilities: loss estimation over train/val splits and the main training loop."""
import logging
import os

import torch

from src.utils.helpers import get_batch
@torch.no_grad()  # evaluation only: don't build autograd graphs (saves memory/compute)
def estimate_loss(model, eval_iters, block_size, batch_size, device):
    """Estimate the mean loss on the 'train' and 'val' splits.

    Runs ``eval_iters`` forward passes per split with the model in eval mode,
    then restores train mode before returning.

    Args:
        model: module whose forward returns ``(logits, loss)`` given ``(xb, yb)``.
        eval_iters: number of batches to average per split.
        block_size: sequence length passed through to ``get_batch``.
        batch_size: batch size passed through to ``get_batch``.
        device: device the input tensors are moved to.

    Returns:
        dict mapping split name ('train'/'val') to its mean loss as a float.
    """
    out = {}
    model.eval()  # disable dropout/batchnorm updates during evaluation
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split, block_size, batch_size)
            xb, yb = xb.to(device), yb.to(device)
            _, loss = model(xb, yb)  # logits unused here
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()  # restore training mode for the caller
    return out
def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
    checkpoint_path="checkpoints/model.pth"
):
    """Run the training loop, periodically evaluating and checkpointing.

    Every ``eval_interval`` steps the train/val losses are estimated via
    ``estimate_loss``; whenever the val loss improves, the model is saved to
    ``checkpoint_path``. The final model is saved unconditionally at the end.

    Args:
        model: module whose forward returns ``(logits, loss)`` given ``(xb, yb)``.
        optimizer: optimizer stepping the model's parameters.
        max_iters: total number of training steps.
        eval_interval: evaluate (and maybe checkpoint) every this many steps.
        eval_iters: batches per split used by ``estimate_loss``.
        block_size: sequence length passed through to ``get_batch``.
        batch_size: batch size passed through to ``get_batch``.
        device: device the input tensors are moved to.
        checkpoint_path: destination file for saved checkpoints.
    """
    logger = logging.getLogger(__name__)

    # torch.save fails with FileNotFoundError if the parent directory is
    # missing — create it up front so the first checkpoint succeeds.
    ckpt_dir = os.path.dirname(checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    best_val_loss = float('inf')

    for step in range(max_iters):  # 'step' avoids shadowing builtin 'iter'
        # Evaluation
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            # Lazy %-formatting: args are only rendered if the record is emitted.
            logger.info(
                "Step %d: train loss %.4f, val loss %.4f",
                step, losses['train'], losses['val'],
            )
            # Save best model
            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                logger.info("Saving model with val loss: %.4f", best_val_loss)
                # NOTE(review): saving the whole module pickles class/code paths;
                # torch.save(model.state_dict(), ...) is more robust, but the
                # whole-model format is kept so existing loaders keep working.
                torch.save(model, checkpoint_path)

        # Training step
        xb, yb = get_batch('train', block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)

        # Forward pass
        logits, loss = model(xb, yb)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)  # frees grad memory vs zero-fill
        loss.backward()
        optimizer.step()

    # Save final model (unconditionally, regardless of val loss)
    torch.save(model, checkpoint_path)