The Ultra-Scale Playbook: Training LLMs on GPU Clusters
TL;DR
First Steps: Training on one GPU
Memory usage in Transformers
Memory profiling a training step
Weights/grads/optimizer states memory
Activations memory
Activation recomputation
Gradient accumulation
Data Parallelism
First optimization: Overlap gradient synchronization with backward pass
Second optimization: Bucketing gradients
Third optimization: Interplay with gradient accumulation
Revisiting the global batch size
Our journey up to now
ZeRO (Zero Redundancy Optimizer)
Memory usage revisited
ZeRO-1: Partitioning Optimizer States
ZeRO-2: Adding Gradient Partitioning
ZeRO-3: Adding Parameter Partitioning
Tensor Parallelism
Tensor Parallelism in a Transformer Block
Sequence Parallelism
Context Parallelism
Introducing Context Parallelism
Discovering Ring Attention
Zig-Zag Ring Attention – A Balanced Compute Implementation
Pipeline Parallelism
Splitting layers on various nodes – all forward, all backward
One-forward-one-backward and Llama 3.1 schemes
Interleaving stages
Zero Bubble and DualPipe
Expert Parallelism
5D parallelism in a nutshell
How to Find the Best Training Configuration
Diving into the GPUs – fusing, threading, mixing
A primer on GPUs
How to improve performance with kernels?
Memory Coalescing
Tiling
Thread Coarsening
Minimizing Control Divergence
Flash Attention 1-3
Fused Kernels
Mixed Precision Training
FP16 and BF16 training
FP8 pretraining
Conclusion
What you learned
What we learned
What’s next?
References
Landmark LLM Scaling Papers
Training Frameworks
Debugging
Distribution Techniques
CUDA Kernels
Hardware
Others
Appendix
A0: Parallel Programming Crash Course
Broadcast
Reduce & AllReduce
A quick focus on Ring All-Reduce
Gather & AllGather
Scatter & ReduceScatter
Barrier
NCCL: NVIDIA Collective Communications Library
A1: Profiling
Kernels
A2: TP Backward pass
A3: ZeRO-R
$P_a$: Partitioned Activation Checkpointing
$C_B$: Constant Size Buffers
$M_D$: Memory Defragmentation
Communication Analysis of ZeRO-R
A5: Memory profile
TP: Practical PyTorch Implementation
Interconnect
How to profile your code
Formulas for the compute / comms balance
Integrating Context Parallelism with TP/SP
The nanotron FP8 recipe
Overlapping computation and communication