The Ultra-Scale Playbook: Training LLMs on GPU Clusters


First Steps: Training on one GPU

Memory usage in Transformers

Memory profiling a training step

Weights/grads/optimizer states memory

Activations memory

Activation recomputation

Gradient accumulation

Data Parallelism

First optimization: Overlap gradient synchronization with backward pass

Second optimization: Bucketing gradients

Third optimization: Interplay with gradient accumulation

Revisit global batch size

Our journey up to now

ZeRO (Zero Redundancy Optimizer)

Memory usage revisited

ZeRO-1: Partitioning Optimizer States

ZeRO-2: Adding Gradient Partitioning

ZeRO-3: Adding Parameter Partitioning

Tensor Parallelism

Tensor Parallelism in a Transformer Block

Sequence Parallelism

Context Parallelism

Introducing Context Parallelism

Discovering Ring Attention

Zig-Zag Ring Attention – A Balanced Compute Implementation

Pipeline Parallelism

Splitting layers on various nodes - All forward, all backward

One-forward-one-backward and Llama 3.1 schemes

Interleaving stages

Zero Bubble and DualPipe

Expert parallelism

5D parallelism in a nutshell

How to Find the Best Training Configuration

Diving in the GPUs – fusing, threading, mixing

A primer on GPU

How to improve performance with kernels?

Memory Coalescing


Thread Coarsening

Minimizing Control Divergence

Flash Attention 1-3

Fused Kernels

Mixed Precision Training

FP16 and BF16 training

FP8 pretraining


What you learned

What we learned

What’s next?


Landmark LLM Scaling Papers

Training Frameworks


Distribution Techniques

CUDA Kernels




A0: Parallel Programming Crash Course


Reduce & AllReduce

A quick focus on Ring All-Reduce

Gather & AllGather

Scatter & ReduceScatter


NCCL: NVIDIA Collective Communications Library

A1: Profiling


A2: TP Backward pass

A3: ZeRO-R

$P_a$: Partitioned Activation Checkpointing

$C_B$: Constant Size Buffers

$M_D$: Memory Defragmentation

Communication Analysis of ZeRO-R

A5: Memory profile

TP: Practical PyTorch Implementation

How to profile your code

Formulas for compute / comms balance

Integrating Context Parallelism with TP/SP

The nanotron FP8 recipe

Overlapping computation and communication