lvwerra HF staff committed on
Commit
89bae42
·
1 Parent(s): 0ccc803

add one gpu section

blog-export-headrs.html ADDED
@@ -0,0 +1,192 @@
1
+ <h2>The Ultra-Scale Playbook: Training LLMs on GPU Clusters</h2>
2
+
3
+ <h2>TL;DR</h2>
4
+
5
+ <h2>First Steps: Training on one GPU</h2>
6
+
7
+ <h3>Memory usage in Transformers</h3>
8
+
9
+ <h4>Memory profiling a training step</h4>
10
+
11
+ <h4>Weights/grads/optimizer states memory</h4>
12
+
13
+ <h4>Activations memory</h4>
14
+
15
+ <h3><strong>Activation recomputation</strong></h3>
16
+
17
+ <h3>Gradient accumulation</h3>
18
+
19
+ <h2>Data Parallelism</h2>
20
+
21
+ <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
22
+
23
+ <h4><strong>Second optimization:</strong> Bucketing gradients</h4>
24
+
25
+ <h4><strong>Third optimization:</strong> Interplay with gradient accumulation</h4>
26
+
27
+ <h3>Revisit global batch size</h3>
28
+
29
+ <h3>Our journey up to now</h3>
30
+
31
+ <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
32
+
33
+ <h4>Memory usage revisited</h4>
34
+
35
+ <h4>ZeRO-1: Partitioning Optimizer States</h4>
36
+
37
+ <h4>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h4>
38
+
39
+ <h4>ZeRO-3: Adding Parameter <strong>Partitioning</strong></h4>
40
+
41
+ <h2>Tensor Parallelism</h2>
42
+
43
+ <h3>Tensor Parallelism in a Transformer Block</h3>
44
+
45
+ <h3>Sequence Parallelism</h3>
46
+
47
+ <h2>Context Parallelism</h2>
48
+
49
+ <h3>Introducing Context Parallelism</h3>
50
+
51
+ <h3>Discovering Ring Attention</h3>
52
+
53
+ <h3>Zig-Zag Ring Attention – A Balanced Compute Implementation</h3>
54
+
55
+ <h2></h2>
56
+
57
+ <h2>Pipeline Parallelism</h2>
58
+
59
+ <h3>Splitting layers on various nodes - All forward, all backward</h3>
60
+
61
+ <h3>One-forward-one-backward and LLama 3.1 schemes</h3>
62
+
63
+ <h3>Interleaving stages</h3>
64
+
65
+ <h3>Zero Bubble and DualPipe</h3>
66
+
67
+ <h2>Expert parallelism</h2>
68
+
69
+ <h2>5D parallelism in a nutshell</h2>
70
+
71
+ <h2>How to Find the Best Training Configuration</h2>
72
+
73
+ <h2>Diving in the GPUs – fusing, threading, mixing</h2>
74
+
75
+ <h4>A primer on GPU</h4>
76
+
77
+ <h3>How to improve performance with Kernels?</h3>
78
+
79
+ <h4>Memory Coalescing</h4>
80
+
81
+ <h4>Tiling</h4>
82
+
83
+ <h4>Thread Coarsening</h4>
84
+
85
+ <h4>Minimizing Control Divergence</h4>
86
+
87
+ <h3>Flash Attention 1-3</h3>
88
+
89
+ <h3>Fused Kernels</h3>
90
+
91
+ <h3>Mixed Precision Training</h3>
92
+
93
+ <h4>FP16 and BF16 training</h4>
94
+
95
+ <h4>FP8 pretraining</h4>
96
+
97
+ <h2>Conclusion</h2>
98
+
99
+ <h3>What you learned</h3>
100
+
101
+ <h3>What we learned</h3>
102
+
103
+ <h3>What’s next?</h3>
104
+
105
+ <h2>References</h2>
106
+
107
+ <h3>Landmark LLM Scaling Papers</h3>
108
+
109
+ <h3>Training Frameworks</h3>
110
+
111
+ <h3>Debugging</h3>
112
+
113
+ <h3>Distribution Techniques</h3>
114
+
115
+ <h3>CUDA Kernels</h3>
116
+
117
+ <h3>Hardware</h3>
118
+
119
+ <h3>Others</h3>
120
+
121
+ <h2>Appendix</h2>
122
+
123
+ <h3>A0: Parallel Programming Crash Course</h3>
124
+
125
+ <h4>Broadcast</h4>
126
+
127
+ <h4>Reduce &amp; AllReduce</h4>
128
+
129
+ <h4><strong>A quick focus on Ring All-Reduce</strong></h4>
130
+
131
+ <h4>Gather &amp; AllGather</h4>
132
+
133
+ <h4>Scatter &amp; ReduceScatter</h4>
134
+
135
+ <h4>Barrier</h4>
136
+
137
+ <h4>NCCL: NVIDIA Collective Communications Library</h4>
138
+
139
+ <h3>A1: Profiling</h3>
140
+
141
+ <h4>Kernels</h4>
142
+
143
+ <h2>Print a table of the profiling results, sorted by total CUDA time, limited to the top 10 entries</h2>
144
+
145
+ <h2>include <torch/extension.h></h2>
146
+
147
+ <h2>include <cuda.h></h2>
148
+
149
+ <h2>include <cuda_runtime.h></h2>
150
+
151
+ <h2>Load and compile the CUDA extension</h2>
152
+
153
+ <h2>Define input tensors</h2>
154
+
155
+ <h2>Run the CUDA kernel</h2>
156
+
157
+ <h3>A2: TP Backward pass</h3>
158
+
159
+ <h3>A3: ZeRO-R</h3>
160
+
161
+ <h4>$P_a:$ Partitioned Activation Checkpointing</h4>
162
+
163
+ <h4><strong>$C_B:$ Constant Size Buffers</strong></h4>
164
+
165
+ <h4><strong>$M_D$: Memory Defragmentation</strong></h4>
166
+
167
+ <h4>Communication Analysis of ZeRO-R</h4>
168
+
169
+ <h3>A5. Memory profile</h3>
170
+
171
+ <h2>Set up optimizer</h2>
172
+
173
+ <h3>TP: Practical PyTorch Implementation</h3>
174
+
175
+ <h2>This is the <code>f</code> function in the paper: https://arxiv.org/abs/1909.08053</h2>
176
+
177
+ <h2>core logic of Column Parallel linear</h2>
178
+
179
+ <h4>Gelu code</h4>
180
+
181
+ <h4>Interconnect</h4>
182
+
183
+ <h3>How to profile your code</h3>
184
+
185
+ <h3>Formulas for compute / comms balance</h3>
186
+
187
+ <h3>Integrating Context Parallelism with TP/SP</h3>
188
+
189
+ <h3>The nanotron FP8 recipe</h3>
190
+
191
+ <h2>Overlapping computation and communication</h2>
192
+
blog-export.html CHANGED
@@ -94,12 +94,6 @@ We’ll neglect these last two contributors as they are typically small and cons
94
  </blockquote>
95
  <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p>
96
<p>So how can I quickly determine memory usage from these variables? One simple way is to do this empirically and just measure it.</p>
97
- <p><img alt="**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%205.png" /></p>
98
- <p><strong>Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done</strong></p>
99
- <p><img alt="In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%206.png" /></p>
100
- <p>In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens.</p>
101
- <p><img alt="**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%205.png" /></p>
102
- <p><strong>Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done</strong></p>
103
  <h3>Memory profiling a training step</h3>
104
  <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
105
  <p><img alt="llama-1b-memory.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/44e32d18-4ed6-455a-a1f7-bbbdebe2fefd.png" /></p>
@@ -172,7 +166,7 @@ m_{act} = L * seq * bs * h * (34 + \frac{5*n_{heads}*seq}{h})</p>
172
<p>It’s time to explain our first technique, called <strong><em>activation recomputation</em></strong>, which will help us cap the activation memory footprint. It is an essential tool in today’s large model training toolbox.</p>
173
  <h2><strong>Activation recomputation</strong></h2>
174
<p>The general idea behind <strong><em>activation recomputation</em></strong>, also called <em>gradient checkpointing</em> or <em>rematerialization</em>, is to discard some activations during the forward pass to save memory and spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically only store activations at a few key points along the model architecture, discard the rest of the activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing a sub-part of the forward pass again to trade off memory for compute. It generally looks like this:</p>
175
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%207.png" /></p>
176
  <p>There are several strategies to select key activations to store:</p>
177
  <ul>
178
  <li><strong>Full</strong>: We checkpoint activations at the transition point between each layer of the Transformer model. This is usually called the <code>full</code> strategy since it requires a forward pass through each layer essentially adding a full forward pass during the backward pass. This strategy saves the most memory but is the most expensive one in terms of compute. It generally increases the compute cost and time by up to 30-40% which is very noticeable.</li>
@@ -204,7 +198,7 @@ m_{act} = L * seq * bs * h * (34 + \frac{5*n_{heads}*seq}{h})</p>
204
  bs = gbs = mbs \times grad_acc
205
  $$</p>
206
  <p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch! </p>
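<p>As a minimal sketch of what a gradient-accumulated training step looks like in PyTorch (the <code>model</code>, <code>optimizer</code>, <code>loss_fn</code> and <code>get_microbatch</code> names are placeholders, not objects defined in this post):</p>
<p>```python
grad_acc_steps = 4                    # gbs = mbs * grad_acc_steps

optimizer.zero_grad()
for step in range(grad_acc_steps):
    inputs, targets = get_microbatch(step)        # one micro-batch of size mbs
    loss = loss_fn(model(inputs), targets)
    # scale so the accumulated gradient is the average over the global batch
    (loss / grad_acc_steps).backward()            # gradients add up in the .grad buffers
optimizer.step()                                  # a single optimizer step per global batch
```</p>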
207
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%208.png" /></p>
208
<p><strong>Gradient accumulation allows us to reduce the memory of activations, which grows linearly with batch size, by computing only partial micro-batches. There is a small overhead caused by the additional forward and backward passes.</strong></p>
209
  <blockquote>
210
<p>Note: Using gradient accumulation means we need to keep buffers where we accumulate gradients, and these persist throughout a training step. Without gradient accumulation, gradients are computed in the backward pass while the activation memory is freed, which means a lower peak memory.
@@ -212,6 +206,14 @@ $$</p>
212
  </blockquote>
213
<p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent of each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p>
214
<p>Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique, called <strong><em>data parallelism</em></strong>, which is just a parallel version of gradient accumulation.</p>
215
  <h1>Data Parallelism</h1>
216
<p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro-batches of data in parallel on each GPU, hence the name Data Parallelism. </p>
217
<p>Using a different micro-batch for each GPU means we’ll have different gradients on each GPU, so to keep the model instances in sync across GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
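<p>As a rough sketch (without the overlapping tricks discussed below), this gradient averaging can be written with <code>torch.distributed</code> as follows; the process-group initialization and <code>model</code> are assumed to exist:</p>
<p>```python
import torch
import torch.distributed as dist

def sync_gradients(model: torch.nn.Module):
    """Average gradients across all data-parallel ranks (naive version)."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # sum over DP ranks
            param.grad /= world_size                           # turn the sum into an average

# per training step: forward/backward on the local micro-batch,
# then sync_gradients(model) followed by optimizer.step()
```</p>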
@@ -303,9 +305,9 @@ $$</p>
303
  </blockquote>
304
  <p>While Data Parallelism is very efficient in scaling training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
305
<p>This approach is organized into three possible optimization stages of ZeRO:</p>
306
- <p>ZeRO-1: optimizer state partitioning ($P_{os}$)</p>
307
- <p>ZeRO-2: optimizer state + gradient partitioning ($P_{os + g}$)</p>
308
- <p>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning ($P_{os + g + p}$)</p>
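<p>To give a rough intuition of ZeRO-1, here is a heavily simplified sketch over a single flat parameter tensor; it assumes an initialized process group, a parameter count that divides evenly by the world size, and uses a plain SGD update in place of a real sharded optimizer (all names are illustrative):</p>
<p>```python
import torch
import torch.distributed as dist

def zero1_style_step(flat_param: torch.Tensor, flat_grad: torch.Tensor, lr: float):
    rank, world = dist.get_rank(), dist.get_world_size()
    shard = flat_param.numel() // world
    lo, hi = rank * shard, (rank + 1) * shard

    # 1) average gradients across DP ranks, as in plain data parallelism
    dist.all_reduce(flat_grad, op=dist.ReduceOp.SUM)
    flat_grad /= world

    # 2) each rank updates only its own 1/N slice, so optimizer state
    #    (here just the update rule; for Adam, the moments) lives only there
    flat_param[lo:hi] -= lr * flat_grad[lo:hi]

    # 3) all-gather the updated slices so every rank holds the full parameters again
    dist.all_gather(list(flat_param.chunk(world)), flat_param[lo:hi].clone())
```</p>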
309
  <blockquote>
310
  <p>Note: You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different microbatch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!
311
  </p>
@@ -467,33 +469,57 @@ SP region needs full hidden_dim</p>
467
  - "f" is an all-reduce to synchronize gradients</p>
468
  <p>These operations "f" and "f*" are called conjugate pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.</p>
469
  <p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
470
- <p>Which is why we call “f” and “f*” conjugate (as explained in <a href="https://arxiv.org/pdf/2205.05198">https://arxiv.org/pdf/2205.05198</a>). For sequence parallelism, in “g*”, we don’t want to all-reduce activations in the “SP” region as that would increase our peak memory usage.</p>
471
- <p><img alt="The MLP block in TP+SP: After matrix multiplications Z1*B1 and Z1*B2, we need to restore the hidden dimension. Instead of using all-reduce like in vanilla TP, we use reduce-scatter along the sequence dimension to maintain sharding. This keeps activations sharded in both TP and SP regions." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2031.png" /></p>
472
- <p>The MLP block in TP+SP: After matrix multiplications Z1<em>B1 and Z1</em>B2, we need to restore the hidden dimension. Instead of using all-reduce like in vanilla TP, we use reduce-scatter along the sequence dimension to maintain sharding. This keeps activations sharded in both TP and SP regions.</p>
473
- <p>Let's look at the MLP block in detail:</p>
474
  <ol>
475
- <li>After the matrix multiplications Z1*B1 and Z1*B2, we need to restore the hidden dimension for correctness (this is part of row-linear)</li>
476
- <li>However, instead of using all-reduce like in vanilla TP, we use reduce-scatter along the sequence dimension. We can do this because the SP region doesn’t need the full sequence dimension.</li>
477
- <li>To optimize memory usage, we overlap the computation of Z1*B1 with the reduce-scatter operation. This keeps our peak memory usage low - the activation shape never exceeds (seq*mbs*hidden_dim/tp) at any point.</li>
478
  </ol>
479
- <p>This approach maintains sharding in both TP and SP regions while ensuring mathematical correctness.</p>
480
- <p>Instead, after the “Z1 B1” and “Z1 B2” matmuls we restore the hidden_dim (see row-linear section above), we need to reduce in (g*) to ensure correctness, because that’s part of the row-linear, but instead of all-reduce we scatter along sequence dimension, since SP region is independent along this dimension, so that we keep activations sharded in both TP and SP regions. </p>
481
- <p>Notice that at the moment where we restore the hidden_dim for “W1” and “W2”, we already have full activations with shape (seq*mbs*hidden_dim), so in practice we overlap the compute of “Z1 B1” with the <em>reduce-scatter</em> to always keep a max activation shape of seq*mbs*hidden_dim/tp</p>
482
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see [TODO: Appendix link]), they’re actually equivalent in terms of communication. The same reasoning applies to the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).</p>
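<p>A small sketch of this equivalence with <code>torch.distributed</code> (assuming an initialized process group and a tensor whose first dimension divides evenly by the world size; the reduce-scatter-then-all-gather order shown is the usual decomposition):</p>
<p>```python
import torch
import torch.distributed as dist

def allreduce_via_rs_ag(x: torch.Tensor) -> torch.Tensor:
    world = dist.get_world_size()
    # reduce-scatter: each rank ends up with the sum of one shard of x
    shard = torch.empty_like(x.chunk(world)[0])
    dist.reduce_scatter(shard, list(x.chunk(world)), op=dist.ReduceOp.SUM)
    # all-gather: collect every rank's summed shard back into a full tensor
    out = torch.empty_like(x)
    dist.all_gather(list(out.chunk(world)), shard)
    return out  # matches dist.all_reduce(x) up to floating-point ordering
```</p>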
483
  <p>It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we find it hard to map as well so we made this small table to summarize how the activations (aka <code>hidden_states</code> ) shape change across hidden dimension h and sequence dimension s during a forward pass:</p>
484
  <p>| Region | Vanilla TP | TP with SP |
485
  | --- | --- | --- |
486
- | Enter TP (Column Linear) | hidden dimension: sharded (weight_out is sharded)
487
- sequence dimension: unchanged | h: sharded (weight_out is sharded)
488
  s: <strong>all-gather</strong> to full |
489
  | TP Region | h: sharded
490
- s: unchanged | h: sharded
491
  s: full |
492
  | Exit TP (Row Linear) | h: full (weight_out is full + <strong>all-reduce</strong> for correctness)
493
- s: unchanged | h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)
494
  s: <strong>reduce-scatter</strong> to sharded |
495
  | SP Region | h: full
496
- s: unchanged | h: full
497
  s: sharded |</p>
498
  <p>And for the embedding layer</p>
499
  <p>| Region | Vanilla TP | TP with SP |
@@ -501,13 +527,12 @@ s: sharded |</p>
501
  | Embedding Layer (Row Linear sharded on vocab) | h: full (weight_out is full + <strong>all-reduce</strong> for correctness)
502
  s: unchanged | h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)
503
  s: <strong>reduce-scatter</strong> to sharded |</p>
504
- <p>You can find an example implementation of both column and row linear TP in picotron:
 
505
  <a href="https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py">https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py</a> </p>
506
- <p>TODO: everything after here is still a bit a mess (comment by leandro)</p>
507
- <p>Experimentally, TP+SP reduces the memory requirements while keeping a minimal compute overhead as we can see here, typically reducing by 30-50% activation memories</p>
508
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2032.png" /></p>
509
  <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops <strong>IN EACH LAYER</strong> (2 for Attention and 2 for MLP), as shown here for the MLP region:</p>
510
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2033.png" /></p>
511
<p>Besides the fact that TP requires communication in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
512
  <blockquote>
513
  <p>Notice how all-gather is overlapped with “Y A1” that’s thanks to this trick
@@ -519,14 +544,14 @@ and you can find more tricks <a href="https://discuss.pytorch.org/t/distributed-
519
  <ul>
520
  <li>TP</li>
521
  </ul>
522
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2034.png" /></p>
523
  <ul>
524
  <li>Seq Parall</li>
525
  </ul>
526
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2035.png" /></p>
527
<p>AllReduce takes almost double the duration (900µs) of ReduceScatter and AllGather (500µs).</p>
528
  <p>Let’s compare throughput as we scale TP and TP/SP for a 3B model:</p>
529
- <p><img alt="Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2036.png" /></p>
530
  <p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
531
  <p>Let’s summarize our observations:</p>
532
  <ul>
@@ -544,11 +569,11 @@ and you can find more tricks <a href="https://discuss.pytorch.org/t/distributed-
544
  <h1>Context Parallelism</h1>
545
  <p>With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g. when scaling to 128k or more tokens per sequence) we might still exceed the memory available on a single node, because inside the TP region we still have to process a full sequence length.</p>
546
  <p>Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:</p>
547
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2037.png" /></p>
548
<p>Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.</p>
549
  <h2>Introducing Context Parallelism</h2>
550
  <p>The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.</p>
551
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2038.png" /></p>
552
  <p>Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.</p>
553
<p>There is one important exception though, which is the <strong><em>attention module</em></strong>. In this module each token needs to access key/value pairs from <strong>all</strong> other sequence tokens or, in the case of causal attention, at least attend to each previous token.</p>
554
  <p>Because Context Parallelism splits the inputs along the sequence dimension across GPUs, the attention module requires full communication between GPUs to exchange the necessary key/value data.</p>
@@ -570,33 +595,33 @@ and you can find more tricks <a href="https://discuss.pytorch.org/t/distributed-
570
  <p><img alt="ring-attention.gif" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/ring-attention.gif" /></p>
571
  <p>With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.</p>
572
<p>There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs stemming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:</p>
573
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2039.png" /></p>
574
<p>The SoftMax is computed row-wise, which means a GPU can compute it as soon as it has received all the tokens of a row. We see that GPU1 can compute it immediately, as it starts with tokens 1-4 and doesn’t need to receive any information from the other GPUs. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all the values for tokens 1-8. GPU1 also seems to perform much less work than all the other GPUs.</p>
575
  <p>Let’s see if we can balance our computations better:</p>
576
  <h2>Zig-Zag Ring Attention – A Balanced Compute Implementation</h2>
577
<p>We need a better way to distribute the input sequences. This can be achieved by assigning the tokens not purely sequentially to the GPUs but by mixing the ordering a bit, such that we have a good mix of early and late tokens on each GPU. This approach is called <a href="https://arxiv.org/pdf/2311.09431">Zig-Zag attention</a>, and in this new arrangement the attention mask shows an even distribution of computation: if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.</p>
578
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2040.png" /></p>
579
  <p>At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.</p>
580
<p>We have two general ways to overlap computation and communication: either we perform a general all-gather, regrouping all the KV on each GPU at the same time (in a ZeRO-3 type of way), or we gather them one-by-one from each GPU as needed:</p>
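<p>Here is a minimal sketch of the all-gather flavour for a single attention layer (our illustrative code, not the Megatron-LM or DeepSpeed implementation): every rank holds only its slice of the sequence, gathers the full K/V, and computes attention for its local queries. Causal masking and the zig-zag reordering are omitted for brevity; shapes are <code>(local_seq, heads, head_dim)</code>.</p>
<p>```python
import torch
import torch.distributed as dist

def cp_attention_allgather(q_local, k_local, v_local):
    world = dist.get_world_size()
    k_parts = [torch.empty_like(k_local) for _ in range(world)]
    v_parts = [torch.empty_like(v_local) for _ in range(world)]
    dist.all_gather(k_parts, k_local)            # full keys on every rank
    dist.all_gather(v_parts, v_local)            # full values on every rank
    k_full, v_full = torch.cat(k_parts), torch.cat(v_parts)
    # local queries attend over the full gathered sequence (no causal mask here)
    scores = torch.einsum("qhd,khd->hqk", q_local, k_full) / k_full.shape[-1] ** 0.5
    return torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), v_full)
```</p>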
581
- <p><img alt="Context Parallelism using AllGather implementation" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2041.png" /></p>
582
  <p>Context Parallelism using AllGather implementation</p>
583
- <p><img alt="Context Parallelism using All-to-All implementation" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2042.png" /></p>
584
  <p>Context Parallelism using All-to-All implementation</p>
585
  <p>TODO: add links to megatronlm(AllGather) and deepspeed(All2All) implementations</p>
586
  <h1></h1>
587
  <h1>Pipeline Parallelism</h1>
588
<p>In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower bandwidth network called the “inter-node connection” which can quite strongly impair our performance. We can see this clearly e.g. on the all-reduce operation when we perform it across several nodes:</p>
589
- <p><img alt="Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2043.png" /></p>
590
  <p>Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.</p>
591
<p>Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.</p>
592
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2044.png" /></p>
593
<p>Pipeline Parallelism is conceptually very simple: we’ll simply spread the layers of our model across GPUs. But the devil lies in implementing it efficiently. Let’s dive into it!</p>
594
  <h2>Splitting layers on various nodes - All forward, all backward</h2>
595
<p>So, let’s say we simply spread the layers on several devices, e.g. a first GPU will take the first few layers, a second GPU will take the second part of the model, and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model and thus successively using each compute device.</p>
596
<p>We have a direct first advantage: the required interconnect bandwidth stays quite low as we only send moderate-sized activations at a handful of locations along the model depth. This is a huge difference compared to, e.g., the communication in Tensor Parallelism, which happens several times within each layer.</p>
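<p>A naive sketch of this forward pass with point-to-point communication (illustrative only: <code>local_stage</code> and the activation shape are assumed, and scheduling, micro-batching and the backward pass are left out):</p>
<p>```python
import torch
import torch.distributed as dist

def pipeline_forward(batch, local_stage, act_shape):
    rank, world = dist.get_rank(), dist.get_world_size()
    if rank == 0:
        x = batch                             # the first stage reads the data
    else:
        x = torch.empty(act_shape)            # buffer for incoming activations
        dist.recv(x, src=rank - 1)            # wait for the previous stage
    x = local_stage(x)                        # this rank's chunk of layers
    if rank < world - 1:
        dist.send(x, dst=rank + 1)            # hand the activations to the next stage
        return None
    return x                                  # only the last stage has the output
```</p>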
597
  <p>But maybe you start feeling a glimpse of the troubles to come: “sequentially” and “successively”?!? This doesn’t sound very efficient in the world of parallel computation, especially after our discussion about computation and communication overlap.</p>
598
<p>Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:</p>
599
- <p><img alt="An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2045.png" /></p>
600
  <p>An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.</p>
601
<p>The remaining idle time is indicated in grey and usually called the “bubble”, and the sight of it probably breaks your heart after we spent so much time optimizing throughput.</p>
602
<p>We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say $t_f$ and $t_b$ are the times for the forward and backward pass, respectively, as measured for one microbatch and one stage of the pipeline (a simple assumption is often $t_b \approx 2 \times t_f$, which you can see on the above graph). If we could perfectly parallelize, the ideal total time would be $t_{id}=t_f + t_b$. However, we can count on the graph that due to the pipeline bubble there is additional time of $t_{pb}=(p-1)*(t_f+t_b)$ (where $p$ is the degree of pipeline parallelism, i.e. the number of GPUs on the above graph), i.e. the time each GPU is waiting while other GPUs are computing.</p>
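<p>To make this concrete, here is a small worked example (our numbers, simply plugging $p=4$ and the $t_b \approx 2 \times t_f$ approximation from above into the formulas):</p>
<p>$$
t_{id} = t_f + t_b \approx 3t_f, \qquad t_{pb} = (p-1)(t_f+t_b) \approx 9t_f
$$</p>
<p>so in this naive schedule each GPU waits roughly three times as long as it computes.</p>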
@@ -607,7 +632,7 @@ $$</p>
607
  <p>As we add more stages the bubble time thus increases and the utilization drops.</p>
608
<p>Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble, which, as you can see in this example, can be very large in a naive implementation.</p>
609
<p>Let’s take a first tool out of our toolbox and think about splitting our batch into smaller, bite-sized portions which can be processed in parallel (or almost), like we did before for data parallelism. Now, while the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:</p>
610
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2046.png" /></p>
611
  <blockquote>
612
<p>Note: previously the numbers indicated layers, but in all pipeline parallel plots from now on, including this one, they indicate a micro-batch. You can think of each square here as containing several layers, as seen in the previous figure.
613
  </p>
@@ -655,9 +680,9 @@ $$</p>
655
<p>Since the memory explosion is triggered by the activations we store for the backward pass, let’s see if we can start performing the backward pass while we are still performing other forward parts of the computation. This will allow us to drop some of the activations we need for the backward pass as soon as possible.</p>
656
  <h2>One-forward-one-backward and LLama 3.1 schemes</h2>
657
<p>This schedule is called <strong>one-forward-one-backward</strong> <strong>(1F1B)</strong> as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:</p>
658
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2047.png" /></p>
659
<p>The bubble still has the same size, so our training efficiency is not significantly improved. However, we only need to store activations for $p$ micro-batches instead of $m$, which significantly reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more micro-batches, which will then actually reduce the bubble.</p>
660
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2048.png" /></p>
661
<p>A major complexity of this setup, visible on the above graph, is that forward and backward passes are no longer cleanly consecutive but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.</p>
662
<p>This is one of the reasons why implementing Pipeline Parallelism usually requires rather extensive modifications to training code as well as modeling code.</p>
663
  <p>Here is the example training loop from the above gist:</p>
@@ -734,32 +759,32 @@ return logging_loss
734
  <p>Well it turns out this is possible if we are willing to bring in a few additional communications. Time to talk about “<strong>Interleaved Stages</strong>”.</p>
735
<p>Up to now we’ve sliced our model naively along the model depth dimension, placing for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could think about slicing our layers, e.g. having odd layers 1, 3, 5, 7 on the first GPU and even layers 2, 4, 6, 8 on the second GPU.</p>
736
  <p>This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.</p>
737
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2049.png" /></p>
738
<p>As a consequence we see additional communications happening, as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes. </p>
739
  <p>$$
740
t_{pb} = \frac{(p-1)*(t_f+t_b)}{v} \\
741
r_{bubble} = \frac{1}{v}\frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{v*m}
742
  $$</p>
743
<p>So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by $v$, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with $p=8$, where the special case of $m=1, v=1$ corresponds to naive pipeline parallelism, the configurations with $v=1$ are AFAB or 1F1B setups, and $v \neq 1$ are interleaved configurations.</p>
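<p>As a quick sanity check of the trade-off, plugging for instance $p=8$ and $m=8$ micro-batches into the bubble formula above (our numbers, for illustration):</p>
<p>$$
r_{bubble} = \frac{p-1}{v\,m} = \frac{7}{1 \times 8} \approx 0.88 \;\; (v=1), \qquad \frac{7}{2 \times 8} \approx 0.44 \;\; (v=2)
$$</p>
<p>so doubling $v$ halves the relative bubble, at the price of $v$ times more communication per micro-batch.</p>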
744
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2050.png" /></p>
745
<p>Scheduling also becomes more complex here, as on a given GPU and at a given moment we need to decide whether we prioritize earlier micro-batches, closing the forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all micro-batches in the queue before going over to the backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in <a href="https://arxiv.org/pdf/2211.05953">https://arxiv.org/abs/2211.05953</a>.</p>
746
<p>You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tunable between depth-first and breadth-first.</p>
747
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2051.png" /></p>
748
<p>However, we haven’t reached the end of possible pipeline schedules, and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!</p>
749
  <h2>Zero Bubble and DualPipe</h2>
750
<p>There are even more sophisticated ways to reduce the bubble and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way. For instance, the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero bubble regime.</p>
751
<p>Let’s very quickly see how this can work by briefly detailing the <a href="https://arxiv.org/abs/2401.10241">ZeroBubble</a> work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):</p>
 
752
  <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2052.png" /></p>
753
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png" /></p>
754
<p>While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.</p>
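<p>To see the B/W split concretely, here is a tiny sketch on a single linear operation using <code>torch.autograd.grad</code> (our illustration of the idea, not the ZeroBubble implementation):</p>
<p>```python
import torch

x = torch.randn(16, 32, requires_grad=True)   # activations coming from the previous stage
w = torch.randn(32, 64, requires_grad=True)   # weights of this stage
y = x @ w
grad_y = torch.randn_like(y)                  # gradient arriving from the next stage

# B: input gradient, needed right away by the previous pipeline stage
grad_x = torch.autograd.grad(y, x, grad_y, retain_graph=True)[0]

# ... other work can be scheduled here to fill the bubble ...

# W: weight gradient, only needed before the optimizer step
grad_w = torch.autograd.grad(y, w, grad_y)[0]
```</p>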
755
<p>DeepSeek’s DualPipe proposes an extension of this decomposed approach to the case of two streams propagating from both sides of the PP ranks, interleaved to minimize idle time in the GPUs even further, as displayed in the following scheduling graph:</p>
756
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2054.png" /></p>
757
<p>The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurements of the time for each operation, followed by a scheduling algorithm able to find the best allocation of time given the constraints. See for instance the <a href="https://arxiv.org/abs/2401.10241">ZeroBubble</a> paper for a discussion of the heuristics and algorithms used to perform such scheduling.</p>
758
  <h1>Expert parallelism</h1>
759
  <p>One more ~~thing~~ parallelism.</p>
760
  <p>Mixture-of-expert models have gained some traction with models such as Mixtral or more recently DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer we can have several and route tokens through different ones depending on their context.</p>
761
  <p>So whereas Context parallelism</p>
762
- <p><img alt="https://arxiv.org/pdf/2407.06204" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2055.png" /></p>
763
  <p><a href="https://arxiv.org/pdf/2407.06204">https://arxiv.org/pdf/2407.06204</a></p>
764
  <p>This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent we can simply put each expert’s feedforward layer on a different worker. Compared to TP it’s much more lightweight, since we don’t need to split the matrix multiplication, we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.</p>
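<p>As a single-device sketch of the routing idea (the expert-parallel version sends each group of tokens to the rank hosting its expert, typically with an all-to-all; the names and the top-1 routing here are illustrative):</p>
<p>```python
import torch
import torch.nn as nn

num_experts, hidden = 4, 128
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])
router = nn.Linear(hidden, num_experts)

def moe_forward(tokens):                         # tokens: (num_tokens, hidden)
    expert_ids = router(tokens).argmax(dim=-1)   # top-1 expert choice per token
    out = torch.empty_like(tokens)
    for eid, expert in enumerate(experts):
        mask = expert_ids == eid
        if mask.any():
            out[mask] = expert(tokens[mask])     # only this expert's tokens
    return out
```</p>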
765
<p>Congratulations reader, we’ve now covered all 4 current directions of parallelism:</p>
@@ -785,7 +810,7 @@ $$</p>
785
<p>In contrast to Pipeline Parallelism, Tensor Parallelism is naturally complementary and interoperable with ZeRO-3. For instance, if one of a model’s submodules (e.g. a layer or layer block) is too large to fit on a GPU when rematerialized by ZeRO-3, there is no other obvious choice than to perform partial local operations and use Tensor Parallelism for this block/sub-model, combined with other dimensions of parallelism like ZeRO-3 or PP as we saw above.</p>
786
<p>Combining ZeRO-3 and TP doesn’t raise any specific issues except how to organize the GPUs in groups for each parallelism dimension. As detailed above, TP will typically be kept for high-speed intra-node communications, while ZeRO-3 can use parallelism groups spanning lower-speed inter-node communications, as the overlap with computation is easy to perform.</p>
787
  <h1>How to Find the Best Training Configuration</h1>
788
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png" /></p>
789
<p>We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. There remains a general question: which ones should we choose and which ones are best combined? We touched on this a little at the end of the last section, but in this section we will walk through the decision process step by step.</p>
790
  <p>| <strong>Method</strong> | <strong>Memory savings</strong> | <strong>Parallel/sharding dimension</strong> | <strong>Disadvantage</strong> |
791
  | --- | --- | --- | --- |
@@ -800,7 +825,7 @@ $$</p>
800
  <p>On the other hand by using PP or TP we can reduce the DP rank significantly, increase the local batch size and thus reduce the necessary global communication. In general, it is natural to combine 4D parallelism with at least ZeRO-1/2 to save optimizer and gradient memory and keep the number of PP stages or TP ranks under control.</p>
801
<p>Overall, most training schedules past a certain model size will tend to combine several of these methods.</p>
802
<p>Let’s try to synthesize the decision process into a relatively simple tree structure: </p>
803
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png" /></p>
804
<p>To explain briefly, data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the BS/GPU at a big enough value to make good use of the GPU MatMul, ZeRO is an easy method to remove memory bottlenecks and stay close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use of more 4D parallelism. In this case, starting with tensor parallelism is the most direct way to reduce memory usage and is generally faster than pipeline parallelism within a single node (8 GPUs). However, in scenarios with long contexts, the primary memory usage will tend to shift from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of this more as a starting point of hyperparameters to run your own benchmarks. For instance, sometimes TP mixed with PP can be more efficient even if TP&lt;8, and ZeRO-1/2 can make sense to mix in with 4D parallelism as well. </p>
805
  <p>This concludes our very deep dive into the distribution methods of 4D parallelism and ZeRO. However, besides scaling our model efficiently across GPUs there is another way to improve model throughput and memory management. </p>
806
  <p>Time to turn the lights off and activate CUDA mode! </p>
@@ -812,17 +837,17 @@ $$</p>
812
  <h3>A primer on GPU</h3>
813
  <p>Generally, GPUs have a very hierarchical organization. In this primer we’ll keep the discussion at the concept levels that are necessary for the rest of our presentation.</p>
814
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">https://resources.nvidia.com/en-us-tensor-core</a> for details), each capable of handling multiple threads simultaneously.</p>
815
- <p><img alt="Original figure from https://blog.codingconfessions.com/p/gpu-computing." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2058.png" /></p>
816
  <p>Original figure from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a>.</p>
817
<p>The memory side is also highly hierarchical, with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong> shared by all SMs; finally, there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query. </p>
818
- <p><img alt="Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2059.png" /></p>
819
  <p>Original figure from <a href="https://www.youtube.com/watch?v=ZQKMZIP3Fzg">https://www.youtube.com/watch?v=ZQKMZIP3Fzg</a></p>
820
<p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute and memory.</p>
821
  <p>A piece of code running on a core of the GPU is called a <strong>kernel</strong>. It can be written at a high-level in <strong>CUDA</strong> or <strong>Triton</strong> for instance, and is then compiled to Parallel Thread Execution, PTX, the low-level assembly used by NVIDIA GPUs.</p>
822
  <p>To run the kernel, you will also need a specific code part (called <strong>host code</strong>) which is executed on the <strong>CPU</strong>/host and will take care of preparing data allocations and loading data and code.</p>
823
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2060.png" /></p>
824
  <p>Figure 5: Host code for a CUDA kernel for adding two vectors from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a></p>
825
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2061.png" /></p>
826
  <p>Figure 6: Device code containing the definition of the vector addition kernel from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a></p>
827
<p>Kernels are generally scheduled as follows:</p>
828
  <ul>
@@ -851,7 +876,7 @@ $$</p>
851
  def elu(x, alpha=1.0):
852
  return torch.where(x &lt; 0, alpha * (torch.exp(x) - 1), x)</code></p>
853
  <p>The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns) :</p>
854
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2062.png" /></p>
855
<p>However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the Triton kernel generated by <code>@torch.compile</code>. To do so, you simply need to set the environment variable <code>TORCH_LOGS</code> to “output_code”: </p>
856
  <p><code>bash
857
  export TORCH_LOGS="output_code"</code></p>
@@ -899,7 +924,7 @@ tl.store(output_ptr + block_indices, output_values, valid_mask)
899
  <p>```</p>
900
<p>Here, <code>tl.program_id(0)</code> provides a unique block ID that we use to determine which section of data that block will process. Using this block ID, <code>block_start</code> calculates the starting index for each block’s section, while <code>block_indices</code> specifies the range of indices within that section. A <code>valid_mask</code> ensures that only indices within <code>num_elements</code> are processed, safely loading the data with <code>tl.load</code>. The ELU function is then applied, modifying values based on whether they're negative, and results are written back to memory with <code>tl.store</code>.</p>
901
  <p>When we benchmark the generated kernel using <code>triton.testing.Benchmark</code> we have the following performance : </p>
902
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2063.png" /></p>
903
<p>This standalone kernel demonstrates superior performance at smaller sizes compared to <code>@torch.compile</code>, but this is likely just an artifact of the compilation time of <code>torch.compile</code>. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process. </p>
904
<p>However, in Triton we sometimes cannot fully achieve the peak performance of the device due to limitations in handling shared memory and scheduling within streaming multiprocessors (SMs). Our access is restricted to blocks, allowing us only to manage the scheduling of blocks across SMs. To gain even more control, we will need to implement kernels in CUDA, where we have access to all the underlying components. </p>
905
  <p>In CUDA, there are various techniques that can be employed to make kernels more efficient; we will present just a few. These include optimizing memory access patterns to reduce latency, using shared memory to store frequently accessed data, and managing thread workloads to minimize idle times. In summary, the tools for writing code to execute instructions on the GPU are:</p>
@@ -929,12 +954,12 @@ tl.store(output_ptr + block_indices, output_values, valid_mask)
929
  <p>}
930
  ```</p>
931
  <p>Here’s an excellent visualization of the kernel from this fantastic <a href="https://siboehm.com/articles/22/CUDA-MMM">blogpost</a> : </p>
932
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2064.png" /></p>
933
  <p>However, when profiling this kernel with a tool like <code>ncu</code>, we can see issues, including low memory throughput and uncoalesced memory accesses.</p>
 
934
  <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2065.png" /></p>
935
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2066.png" /></p>
936
  <p>The reason for this is that in this kernel, two threads in the same block with Thread IDs <code>(0, 0)</code> and <code>(1, 0)</code> (which will end up in the same warp) will both load from the same column of matrix <code>B</code> but different rows of matrix <code>A</code>. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with <code>i = 0</code>, thread <code>(0, 0)</code> will load $A_{0,0}$, and thread <code>(1, 0)</code> will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.</p>
937
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png" /></p>
938
  <p>To improve our kernel we can change the way the coordinates x and y are calculated like the following : </p>
939
  <p>```cpp
940
  const int x = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
@@ -949,14 +974,14 @@ const int y = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);</p>
949
  ```</p>
950
  <p>Instead of using a 2D block, we switch to a 1D block and redefine how we determine the values of <code>x</code> and <code>y</code>. In this new method, threads within the same warp (which have close <code>threadIdx.x</code> values) will share the same <code>x</code> value but have different <code>y</code> values. This means that they will load the same row of matrix <code>A</code> but different columns of matrix <code>B</code>. As a result, memory accesses can be coalesced for a row-major matrix.</p>
951
  <p>When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and <strong>the GPU's memory throughput has increased by approximately 10 times</strong>.</p>
952
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2068.png" /></p>
953
  <p>We also notice that the execution time of the kernel <strong>decreases by 10x</strong>!</p>
954
  <p>Let’s cover another technique you will often see mentioned in the literature: tiling.</p>
955
  <h3>Tiling</h3>
956
  <p>Tiling is a technique that leverages <em>shared memory</em> to optimize memory access patterns. As we mentioned above, the shared memory is a small, fast memory accessible by all threads within a block. It allows data to be reused by multiple threads, reducing the need to repeatedly load data from slower global memory.</p>
957
  <p>In matrix multiplication for example, each thread in a block may need elements from two matrices, say A and B. If each thread independently loads the row and column it needs from global memory, we end up with many redundant loads, as multiple threads in a block will access overlapping data. Instead, we can use tiling to load a block (or tile) of A and B into shared memory just once, allowing all threads in that block to reuse the same shared data.</p>
958
  <p>In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size <code>BLOCK_SIZE_M</code> by <code>BLOCK_SIZE_K</code>) and a tile of matrix B (of size <code>BLOCK_SIZE_K</code> by <code>BLOCK_SIZE_N</code>). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed. </p>
959
- <p><img alt="From https://cnugteren.github.io/tutorial/pages/page4.html" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2069.png" /></p>
960
  <p>From <a href="https://cnugteren.github.io/tutorial/pages/page4.html">https://cnugteren.github.io/tutorial/pages/page4.html</a></p>
961
  <p>The important parts to understand the implementation are below (for simplicity we consider a square shaped tile) : </p>
962
  <p>```cpp</p>
@@ -988,7 +1013,7 @@ C[localRow * N + localCol] = sum;
988
  <p>When benchmarking this kernel using <code>ncu</code>, we noticed that the memory throughput increased to 410 GB/s and that the kernel execution time decreased by ~43%, reaching a performance of ~6.6 TFLOPs.</p>
989
  <h3>Thread Coarsening</h3>
990
  <p>The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:</p>
991
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png" /></p>
992
  <p>The meaning of the states can be found in the <a href="https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference">Profiling Guide</a>, specifically in the <strong>Warp Stall Reasons</strong> section. There we can read that : </p>
993
  <blockquote>
994
  <p>smsp__pcsamp_warps_issue_stalled_mio_throttle : Warp was stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure.
@@ -1003,10 +1028,10 @@ C[localRow * N + localCol] = sum;
1003
  <h2>Flash Attention 1-3</h2>
1004
  <p>Flash attention is a technique pioneered by <a href="https://tridao.me">Tri Dao</a> that optimizes the attention computation by writing custom CUDA kernels to make it much faster <em>and</em> more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU and avoid relying too much on the slowest one: the global memory of the GPU (confusingly called High Bandwidth Memory, HBM 🫠).</p>
1005
  <p>A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
1006
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2071.png" /></p>
1007
  <p>Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
1008
  <p>The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, keeping only the statistics necessary to compute the normalization factor of the softmax. This way we can compute part of $O$ directly in SRAM in one go rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also remove the memory bottleneck that comes from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.</p>
1009
- <p><img alt="From the FLASH-ATTENTION paper (https://arxiv.org/pdf/2205.14135)" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2072.png" /></p>
1010
  <p>From the FLASH-ATTENTION paper (<a href="https://arxiv.org/pdf/2205.14135">https://arxiv.org/pdf/2205.14135</a>)</p>
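<p>To make the idea of streaming over K/V blocks while keeping only running softmax statistics more concrete, here is a minimal PyTorch sketch of the principle for a single head (a toy, unfused illustration with names and a block size of our own choosing, not the actual Flash Attention kernel):</p>
<p>```python
import torch

def blocked_attention(Q, K, V, block_size=128):
    # Q, K, V: (seq_len, head_dim). Toy single-head version, no masking.
    seq_len, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)                                # running (unnormalized) output
    row_max = torch.full((seq_len, 1), float("-inf"))      # running max of the scores per query
    row_sum = torch.zeros(seq_len, 1)                      # running softmax denominator

    for start in range(0, seq_len, block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        S = (Q @ Kb.T) * scale                             # scores for this K/V block only
        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - new_max)                         # block-local softmax numerator
        correction = torch.exp(row_max - new_max)          # rescale previously accumulated stats
        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        O = O * correction + P @ Vb
        row_max = new_max

    return O / row_sum                                     # final normalization

# Sanity check against the naive attention that materializes the full S matrix
Q, K, V = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((Q @ K.T) * 64 ** -0.5, dim=-1) @ V
torch.testing.assert_close(blocked_attention(Q, K, V), ref, rtol=1e-4, atol=1e-4)
```</p>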
1011
  <p>The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:</p>
1012
  <ul>
@@ -1046,10 +1071,10 @@ C[localRow * N + localCol] = sum;
1046
  </p>
1047
  </blockquote>
1048
  <p>Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice bits on either the mantissa or the exponent. For this reason there also exist two float8 formats, named according to their exponent and mantissa bits, so that we can flexibly choose the most appropriate one. Let’s look at the possible range of numbers for each format:</p>
1049
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2073.png" /></p>
1050
  <p>We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range while e4m3 has an even smaller range.</p>
1051
  <p>How come some formats are able to maintain the range and others are not? Let’s investigate their resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
1052
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2074.png" /></p>
1053
  <p>We can see here that bfloat16 maintained the range of float32 over float16, but did this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire: e4m3 can represent only 7 and e5m2 only 3 numbers on the interval 1-2.</p>
1054
  <p>A common metric to measure a format’s resolution is epsilon: the gap between 1.00 and the next representable number. We can see that for the float32 format $10^{-4}$ is an upper bound (it’s actually $1.19 \times 10^{-7}$). For float16 it is $\sim 10^{-3}$ and for bfloat16 it is 10x higher still. </p>
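<p>You can check these numbers yourself with a few lines of PyTorch. A small sketch (the float8 dtypes below, <code>torch.float8_e4m3fn</code> and <code>torch.float8_e5m2</code>, assume a recent PyTorch version where they are available):</p>
<p>```python
import torch

# Epsilon: the gap between 1.0 and the next representable number
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
    print(dtype, torch.finfo(dtype).eps)
# float32 ~1.19e-07, float16 ~9.8e-04, bfloat16 ~7.8e-03

# Round 10,000 points between 1 and 2 to each format and count the distinct values
points = torch.linspace(1, 2, 10_000)
for dtype in [torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]:
    n_distinct = points.to(dtype).to(torch.float32).unique().numel()
    print(dtype, n_distinct)
```</p>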
1055
  <p>The idea of mixed precision training is to use some of these lower-precision formats while maintaining the performance of full precision training. It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision. </p>
@@ -1067,7 +1092,7 @@ C[localRow * N + localCol] = sum;
1067
  <h3>FP8 pretraining</h3>
1068
  <p>Even if we perfectly overlap communication with computation, we always eventually run into the low level theoretical FLOPS limit of the hardware itself, i.e. the efficiency of each individual operation on our hardware. This is where numerical precision becomes crucial. For instance, on NVIDIA's H100 GPU, FP8 matrix multiplications (GEMM operations) achieve twice the theoretical FLOPS of bfloat16, making lower-precision training an attractive path for further optimization.</p>
1069
  <p>Recent research - including <a href="https://arxiv.org/abs/2310.18313">FP8-LM</a>, <a href="https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8">torchao</a>, and <a href="https://arxiv.org/abs/2412.19437">DeepSeek-V3</a> - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.</p>
1070
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2075.png" /></p>
1071
  <p>As <a href="https://arxiv.org/abs/2309.14322">[Wortsman et al.]</a> observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.</p>
1072
  <p>The first successful very large scale training with FP8 mixed precision was publicly reported with DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward passes. Similar to BF16 mixed precision training, some aggregations and master weights are kept in higher precision while the operations themselves are performed in FP8. </p>
1073
  <p><img alt="Screenshot 2025-02-09 at 22.20.28.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/Screenshot_2025-02-09_at_22.20.28.png" /></p>
@@ -1171,11 +1196,11 @@ C[localRow * N + localCol] = sum;
1171
  </ul>
1172
  <p>Throughout this blogpost we’ll scale LLM training from one to hundreds of GPUs. This will require the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called <strong>collective operations</strong>. In this section we’ll do a small crash course of <em>Broadcast, AllReduce, Scatter</em> and co. but if you are already familiar with these patterns feel free to move directly to [SECTION I], otherwise let’s get ☕ #1 (or your neural stimulant of choice) and let’s dig in! </p>
1173
  <p>The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1). </p>
1174
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2076.png" /></p>
1175
  <p>Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted <code>root</code>, which is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.</p>
1176
  <h3>Broadcast</h3>
1177
  <p>A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:</p>
1178
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2077.png" /></p>
1179
  <p>Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with <code>dist.init_process_group</code>, which sets up the communication backend (we’ll talk about NCCL later), determines how many workers (aka nodes) exist, and assigns a rank to each one (which we can get with <code>dist.get_rank</code>). Finally, it establishes a connection between the workers.</p>
1180
  <p>To showcase the <code>broadcast</code> operation, let's create a tensor with non-zero values on <code>rank=0</code> and tensors full of zeros on the other workers. We then distribute the <code>rank=0</code> tensor to all other ranks with <code>dist.broadcast(tensor, src=0)</code> :</p>
1181
  <p>```python
@@ -1207,7 +1232,7 @@ After broadcast on rank 2: tensor([1., 2., 3., 4., 5.], device='cuda:2')
1207
  <p>Great, seems like it works as expected. Note that the rank messages can be printed out of order as we have no control over which print statement is executed first (we ordered them here for readability). Now let’s move on to the Reduce and AllReduce patterns! </p>
1208
  <h3>Reduce &amp; AllReduce</h3>
1209
  <p>Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function <code>f()</code> which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:</p>
1210
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2078.png" /></p>
1211
  <p>Of course there is no magic “free-flying” node that can perform this operation; generally each node does a partial computation, with the nodes arranged in a ring or tree structure. Here is a simple example: let’s say we need to compute the sum of numbers held on each node and that our nodes are connected in a ring pattern. The first node sends its number to a neighbour, which adds its own number to the received one before forwarding it to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.</p>
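<p>Before looking at the PyTorch primitive, here is a toy, single-process sketch of this ring idea (plain Python, no communication library; the values are made up):</p>
<p>```python
# Node 0 sends its value to its neighbour, each neighbour adds its own value and
# forwards the partial sum; after one full round, node 0 holds the total.
values = [1.0, 2.0, 3.0, 4.0]       # the number held by node 0..3

partial_sum = values[0]              # node 0 starts the ring
for rank in range(1, len(values)):   # the message travels 1 -> 2 -> 3 -> back to 0
    partial_sum += values[rank]      # node `rank` adds its value and forwards it

print("Node 0 receives the total sum:", partial_sum)   # 10.0
```</p>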
1212
  <p>Here’s the code to run a simple Reduce operation summing the tensors, we specify the operation to use with <code>op=dist.ReduceOp.SUM</code> (you can find more information on the supported operations in the docs: <a href="https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp">https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp</a>):</p>
1213
  <p>```python
@@ -1278,7 +1303,7 @@ After all_reduce on rank 2: tensor([6., 6., 6., 6., 6.], device='cuda:2')
1278
  <p>Now let’s turn to our next distributed communication operation. In many real cases, each node individually perform many complex computations and we need to share the final results among nodes. Gather and AllGather are the operations we want to use in this case. Let’s take a look! </p>
1279
  <h3>Gather &amp; AllGather</h3>
1280
  <p>Gather and AllGather are quite similar to Broadcast in that they distribute data among nodes without modification. The main difference is that there is not one value we need to share from one node to all other nodes; instead, each node has an individual chunk of data, and we want to either gather all the chunks on one node (in the case of Gather) or gather all the chunks on all nodes (in the case of AllGather). A picture being worth 1000 words, let’s take a look:</p>
1281
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2079.png" /></p>
1282
  <p>Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).</p>
1283
  <p>In the case of the gather operation we need to prepare a container object where the gathered tensors can be stored, in this example the <code>gather_list</code>:</p>
1284
  <p>```python
@@ -1334,7 +1359,7 @@ After all_gather on rank 2: [tensor([1., 1., 1., 1., 1.], device='cuda:2'),
1334
  <h3>Scatter &amp; ReduceScatter</h3>
1335
  <p>As the name subtly suggests, the goal of the Scatter operation is to take data on one node and distribute slices of it to all other nodes. It’s thus different from the Broadcast operation, which copies data without slicing, and it’s the logical inverse of the Gather operation.</p>
1336
  <p>The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case, but instead of moving the full result to just one node, you distribute slices of it evenly across all nodes:</p>
1337
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2080.png" /></p>
1338
  <p>The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the <code>src</code>:</p>
1339
  <p>```python
1340
  def example_scatter():
@@ -1393,7 +1418,7 @@ After ReduceScatter on rank 2: tensor([ 36., 288.], device='cuda:2')
1393
  <p>We have now seen the main building blocks of distributed operations, but before we see them in action let’s have a look at a special operation used for synchronization: the Barrier.</p>
1394
  <h3>Barrier</h3>
1395
  <p>The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Only then are they allowed to continue with further computations:</p>
1396
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2081.png" /></p>
1397
  <p>We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:</p>
1398
  <p>```python
1399
  def example_barrier():
@@ -1478,14 +1503,14 @@ import torch.nn.functional as F</p>
1478
  <p>print(p.key_averages().table(sort_by="cuda_time_total", row_limit=8))</p>
1479
  <p>```</p>
1480
  <p>This would print aggregated profiling results sorted by the total CUDA time, and the output would be:</p>
1481
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2082.png" /></p>
1482
  <p>You can also try to inspect the trace as we previously mentioned on <code>chrome://tracing/</code> </p>
1483
  <blockquote>
1484
  <p>If you're new to this tool, you can navigate the trace by using the right and left arrow keys. Additionally, you can zoom in and out by holding the <strong>Alt</strong> key while scrolling left or right with your mouse.
1485
  </p>
1486
  </blockquote>
1487
  <p>After zooming in, you can observe the flow of operations when calling <code>layer_norm</code> in this trace:</p>
1488
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2083.png" /></p>
1489
  <p>The sequence begins in the CPU (the upper section) with <code>aten::layer_norm</code>, progressing to <code>aten::native_layer_norm</code>, and then transitioning to <code>cudaLaunchKernel</code>. From there, we move on to the GPU, where the <code>vectorized_layer_norm_kernel</code> kernel is called. </p>
1490
  <blockquote>
1491
  <p>Note that you can enable memory profiling by setting <code>profile_memory</code> to <code>True</code> in the profiler. However, this can lead to more complex traces.
@@ -1498,7 +1523,7 @@ ncu --set full python layer_norm.py</code></p>
1498
  <p><code>python
1499
  ncu --set full -o output python layer_norm.py</code></p>
1500
  <p>and open the file <code>output.ncu-rep</code> with Nsight Compute, you will have a view that looks like this : </p>
1501
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2084.png" /></p>
1502
  <p>It comes with clear warnings about compute and memory utilization, as well as hints on how to better balance compute and memory in the kernel and achieve maximal occupancy.</p>
1503
  <p><strong>CPP extension</strong></p>
1504
  <p>If the kernel you want to profile isn’t already integrated into PyTorch, you can use PyTorch's <code>cpp_extension</code> module to easily compile and run custom CUDA code. The process is straightforward—just create your CUDA kernel in a <code>.cu</code> file, and use the <code>load</code> function from the <code>cpp_extension</code> module to load it in Python.</p>
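<p>As a rough sketch (the file name <code>my_kernel.cu</code> and the function <code>my_layer_norm</code> below are hypothetical and only serve to illustrate the API):</p>
<p>```python
import torch
from torch.utils.cpp_extension import load

# Compiles the .cu file on the fly and exposes the functions it binds (e.g. via pybind11)
my_ext = load(
    name="my_kernel",            # hypothetical extension name
    sources=["my_kernel.cu"],    # hypothetical source file
    extra_cuda_cflags=["-O3"],
    verbose=True,
)

x = torch.randn(1024, 1024, device="cuda")
y = my_ext.my_layer_norm(x)      # call the custom CUDA function from Python
```</p>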
@@ -1558,12 +1583,12 @@ $$</p>
1558
  \frac{dL}{dX} = \frac{dL}{dY} \frac{dY}{dX} = \frac{dL}{dY} W</p>
1559
  <p>$$</p>
1560
  <p>The chain rule applies here since the loss (L) depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).</p>
1561
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2085.png" /></p>
1562
  <p>Likewise, we can use chain rule to compute the gradient w.r.t to the weight:</p>
1563
  <p>$$
1564
  \frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = \frac{dL}{dY} X
1565
  $$</p>
1566
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2086.png" /></p>
1567
  <p>Here is a snippet of code to clarify all the concepts above:</p>
1568
  <p>```python
1569
  def column_linear_forward(X, local_W, group):
@@ -1634,10 +1659,10 @@ torch.testing.assert_close(grad_W, split_tensor(W_ref_layer1.grad, dim=0), rtol=
1634
  <pre><code>example_column_row_linear()
1635
  </code></pre>
1636
  <p>```</p>
1637
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2087.png" /></p>
1638
  <p><strong>TODO</strong> add these illustrations somewhere? I found them helpful:</p>
 
1639
  <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2088.png" /></p>
1640
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2089.png" /></p>
1641
  <h2>A3: ZeRO-R</h2>
1642
  <p>To further optimize memory usage in large-scale training, DeepSpeed ZeRO-R introduces several techniques aimed at reducing the memory footprint of activation values during forward and backward propagation. The key strategies include partitioned activation checkpointing, the use of constant-size buffers, and memory defragmentation.</p>
1643
  <h3>$P_a:$ Partitioned Activation Checkpointing</h3>
@@ -1750,7 +1775,7 @@ torch.testing.assert_close(y_row_1, y_row_2, rtol=1e-5, atol=1e-5)
1750
  </code></pre>
1751
  <p>```</p>
1752
  <h3>Interconnect</h3>
1753
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2090.png" /></p>
1754
  <h2>How to profile your code</h2>
1755
  <p>The profiler is a tremendously useful tool and easy to use. It takes three steps to profile your program:</p>
1756
  <ol>
@@ -1782,7 +1807,7 @@ with profiler: # step 2. Wrap the training with profiler
1782
  </code></pre>
1783
  <p>```</p>
1784
  <p>After running this code, you will find <code>*.trace.json</code> files under the <code>profiler_out_dir</code>. To visualize the results, the easiest way is to open Google Chrome, go to <code>chrome://tracing/</code>, and drag the file into it. This will allow you to view the profiling results. To get more details, we invite you to check out the amazing <a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html">tutorial</a> created by PyTorch.</p>
1785
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2091.png" /></p>
1786
  <h2>Formulas for the compute / comms balance</h2>
1787
  <p>```markdown</p>
1788
  <p>Estimates: (number of elements, need to multiply by bytes_per_element)
@@ -1884,7 +1909,7 @@ peak_bandwidth_p2p = (S / t)
1884
  <p>for a single microbatch:
1885
  -&gt; t_comm / t_compute = seq * mbs * h * peak_flops / (num_layers_in_next_pp * 32 * seq * mbs * h^2 * peak_bw) = peak_flops / (32 * h * num_layers_in_next_pp * peak_bw)</p>
1886
  <p>```</p>
1887
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2092.png" /></p>
1888
  <h2>Integrating Context Parallelism with TP/SP</h2>
1889
  <p>We’ve seen that both TP/SP and CP shard the activations along the sequence dimension, and both require communications in the Attention module. Wouldn’t that create issues? In fact, not at all!</p>
1890
  <p>In order to integrate CP with TP/SP we just have to:</p>
@@ -1894,14 +1919,14 @@ peak_bandwidth_p2p = (S / t)
1894
  <li><strong>Replace standard attention with ring attention:</strong> each TP rank relies on ring attention to compute the correct attention results during both the forward and backward passes. So all CP ranks within TP=0, for example, need to all-gather the full KV sequence and calculate attention, but we store only the KV of one sequence chunk, which is how CP reduces activation memory.</li>
1895
  </ol>
1896
  <p><img alt="TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
1897
- TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2093.png" /></p>
1898
  <p>TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
1899
  TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank</p>
1900
  <p>Context parallelism is naturally compatible with data parallelism which splits the input along the batch size dimension.</p>
1901
  <p>In fact, given an activation value of shape $[\text{batch\_size}, \text{sequence\_length}, \text{hidden\_dimension}]$, data parallelism, sequence/context parallelism, and tensor parallelism split it across the 1st, 2nd, and 3rd dimensions, respectively, and these are independent of each other.</p>
1902
  <h2>The nanotron FP8 recipe</h2>
1903
  <p>However, through extensive experimentation, we identified two effective training recipes that allowed us to <strong>fully pretrain a 1B LLaMA model in FP8</strong>, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?</p>
1904
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2094.png" /></p>
1905
  <p>A loss curve that perfectly matches mixed-precision bfloat16 (bfloat16 with FP32 master weights as the baseline). We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.</p>
1906
  <p>Here’s what worked:</p>
1907
  <ul>
@@ -1932,9 +1957,9 @@ TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get
1932
  <p>The general idea is always the same: if there are parts we will need to communicate soon between workers and that are independent of the current computation, we can parallelize, or as it is also called overlap, the communication and computation. </p>
1933
  <p>Let’s take a moment to look better at this fundamental tool for distributed training and go over the example of Ring Attention using PyTorch Profiler. In its implementation, we can overlap the sending and receiving of key/value pairs with the computation of attention scores. What does this look like?</p>
1934
  <p><strong>Non-overlapping:</strong> If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete and total time is the sum of communication and computation.</p>
1935
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2095.png" /></p>
1936
  <p><strong>Overlapping:</strong> However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now the computations are launched immediately, one after the other, while the communications run alongside them. In this case the total time is <em>only</em> the sum of computations.</p>
1937
- <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2096.png" /></p>
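<p>In PyTorch this kind of overlap is typically expressed with asynchronous point-to-point calls or collectives launched with <code>async_op=True</code>: we start the communication, run the independent computation, and only wait on the handle when the result is actually needed. A minimal sketch (assuming the process group is already initialized; the tensor names are illustrative and this is not the actual Ring Attention implementation, which batches these sends/receives):</p>
<p>```python
import torch
import torch.distributed as dist

def overlapped_ring_step(local_q, current_kv, next_kv_buffer):
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Launch the exchange of the next K/V chunk without blocking...
    send_req = dist.isend(current_kv, dst=(rank + 1) % world_size)
    recv_req = dist.irecv(next_kv_buffer, src=(rank - 1) % world_size)

    # ...and compute on the chunk we already have in the meantime.
    partial_scores = local_q @ current_kv.T    # stand-in for the real attention math

    # Only block once we actually need the freshly received chunk.
    send_req.wait()
    recv_req.wait()
    return partial_scores
```</p>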
1938
  <p>Context parallelism has helped us go past the intra-node interconnect bottleneck, which blocked us from scaling TP across nodes. However, as you probably noted, it only helps reduce the memory constraints if the activation memory dominates the memory budget due to long sequences. What if we are not working on super long sequences and the model weights alone are too big for a single node?</p>
1939
  <p>Well, it turns out we have another – quite different – option called pipeline parallelism (PP), which it is now time to explore.</p>
1940
  <p>[TODO: comment from Nouamane on comms overlapping with DP 512]</p>
 
94
  </blockquote>
95
  <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p>
96
  <p>So how can I quickly determine memory usage from these variables? One simple way is to do this empirically and just measure it.</p>
 
 
 
 
 
 
97
  <h3>Memory profiling a training step</h3>
98
  <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
99
  <p><img alt="llama-1b-memory.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/44e32d18-4ed6-455a-a1f7-bbbdebe2fefd.png" /></p>
 
166
  <p>It’s time to explain our first technique – called <strong><em>activation recomputation</em></strong> – which will help us cap the activation memory footprint. An essential tool in today’s large model training toolbox.</p>
167
  <h2><strong>Activation recomputation</strong></h2>
168
  <p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory, and to spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), so that we can use them during the backward pass to compute gradients. When we use recomputation, we typically only store activations at a few key points along the model architecture, discard the rest, and recompute them on the fly during the backward pass from the nearest saved activations, basically performing a sub-part of the forward pass again to trade memory for compute. It generally looks like this:</p>
169
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%205.png" /></p>
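<p>In PyTorch this is exposed through <code>torch.utils.checkpoint</code>; a minimal sketch with a toy block (shown only to illustrate the API, not our actual training code):</p>
<p>```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
)
x = torch.randn(8, 1024, requires_grad=True)

# The activations inside `block` are not stored; they are recomputed during backward.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```</p>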
170
  <p>There are several strategies to select key activations to store:</p>
171
  <ul>
172
  <li><strong>Full</strong>: We checkpoint activations at the transition point between each layer of the Transformer model. This is usually called the <code>full</code> strategy since it requires a forward pass through each layer essentially adding a full forward pass during the backward pass. This strategy saves the most memory but is the most expensive one in terms of compute. It generally increases the compute cost and time by up to 30-40% which is very noticeable.</li>
 
198
  bs = gbs = mbs \times grad_acc
199
  $$</p>
200
  <p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback, however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step, thereby increasing the compute overhead and slowing down training. No free lunch! </p>
201
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%206.png" /></p>
202
+ <p><strong>Gradient accumulation allows us to reduce the activation memory, which grows linearly with batch size, by processing only partial micro-batches at a time. There is a small overhead caused by the additional forward and backward passes.</strong></p>
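<p>A minimal sketch of what this looks like in a training loop (the model, data and <code>grad_acc_steps</code> value are placeholders):</p>
<p>```python
import torch

model = torch.nn.Linear(512, 512)                     # placeholder model
optimizer = torch.optim.AdamW(model.parameters())
grad_acc_steps = 4                                    # gbs = mbs * grad_acc_steps (no DP here)

optimizer.zero_grad()
for step in range(100):
    x = torch.randn(8, 512)                           # one micro-batch (mbs = 8)
    loss = model(x).pow(2).mean() / grad_acc_steps    # scale so the accumulated gradients are averaged
    loss.backward()                                   # gradients accumulate in the .grad buffers

    if (step + 1) % grad_acc_steps == 0:              # one optimizer step per grad_acc_steps micro-batches
        optimizer.step()
        optimizer.zero_grad()
```</p>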
203
  <blockquote>
204
  <p>Note: Using gradient accumulation means we need to keep buffers where we accumulate gradients, and these persist throughout a training step. Without gradient accumulation, gradients are computed in the backward pass while the activation memory is freed, which means a lower peak memory.
 
206
  </blockquote>
207
  <p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p>
208
  <p>Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called <em><strong>data parallelism</strong> which is just a parallel version of gradient accumulation</em>.</p>
209
+ <p>TODO: intro for this</p>
210
+ <h2>torch.profiler</h2>
211
+ <p><img alt="**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%207.png" /></p>
212
+ <p><strong>Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done</strong></p>
213
+ <p><img alt="In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%208.png" /></p>
214
+ <p>In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens.</p>
217
  <h1>Data Parallelism</h1>
218
  <p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism. </p>
219
  <p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
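<p>In its most naive form (ignoring for now the overlapping tricks discussed below), this synchronization boils down to a few lines. A sketch assuming the process group is already initialized:</p>
<p>```python
import torch
import torch.distributed as dist

def sync_gradients(model: torch.nn.Module):
    # Average the gradients across all DP ranks after the backward pass.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= dist.get_world_size()

# In the training loop: loss.backward(); sync_gradients(model); optimizer.step()
```</p>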
 
305
  </blockquote>
306
  <p>While Data Parallelism is very efficient in scaling training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p>
307
  <p>This approach is organized into three possible optimization stages of ZeRO:</p>
308
+ <p>ZeRO-1: optimizer state partitioning</p>
309
+ <p>ZeRO-2: optimizer state + gradient partitioning</p>
310
+ <p>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</p>
311
  <blockquote>
312
  <p>Note: You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different microbatch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!
313
  </p>
 
469
  - "f" is an all-reduce to synchronize gradients</p>
470
  <p>These operations "f" and "f*" are called conjugate pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.</p>
471
  <p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
472
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/c65c0745-6dda-4f5c-a7ae-0092e50cdc0f.png" /></p>
473
+ <p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p>
 
 
474
  <ol>
+ <li><strong>Initial LayerNorm (SP Region)</strong>
+ <ul>
+ <li>Input tensors X1* and X2* (b,s/2,h) enter LayerNorm, already split across the sequence dimension</li>
+ <li>Each GPU computes LayerNorm independently on its sequence chunk, giving Y1* and Y2*</li>
+ </ul>
+ </li>
+ <li><strong>First Transition (SP → TP)</strong>
+ <ul>
+ <li>"g" operation (all-gather) combines Y1* and Y2* back to the full sequence length</li>
+ <li>Restores Y (b,s,h) since the column-linear layer needs the full hidden dimension h</li>
+ </ul>
+ </li>
+ <li><strong>First Linear Layer (TP Region)</strong>
+ <ul>
+ <li>A1 is a column-linear layer, so it splits Y along the hidden dimension</li>
+ <li>GeLU is applied independently on each GPU</li>
+ <li>Z1* is (b,s,h/2)</li>
+ </ul>
+ </li>
+ <li><strong>Second Linear Layer (TP Region)</strong>
+ <ul>
+ <li>B1 is a row-linear layer, so it restores the hidden dimension</li>
+ <li>W1 is (b,s,h)</li>
+ </ul>
+ </li>
+ <li><strong>Final Transition (TP → SP)</strong>
+ <ul>
+ <li>"g*" operation (reduce-scatter) which reduces for the previous row-linear's correctness while scattering along the sequence dimension</li>
+ <li>W1* is (b,s/2,h)</li>
+ </ul>
+ </li>
  </ol>
508
+ <p>A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to $\frac{b \cdot s \cdot h}{tp}$ since we always either split along the sequence or hidden dimensions.</p>
 
 
 
509
  <p>It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP – believe us, we find it hard to map as well – so we made this small table to summarize how the activation (aka <code>hidden_states</code>) shapes change across the hidden dimension h and the sequence dimension s during a forward pass:</p>
510
  <p>| Region | Vanilla TP | TP with SP |
  | --- | --- | --- |
  | Enter TP (Column Linear) | h: sharded (weight_out is sharded)<br>s: full | h: sharded (weight_out is sharded)<br>s: <strong>all-gather</strong> to full |
  | TP Region | h: sharded<br>s: full | h: sharded<br>s: full |
  | Exit TP (Row Linear) | h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: full | h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded |
  | SP Region | h: full<br>s: full | h: full<br>s: sharded |</p>
524
  <p>And for the embedding layer</p>
525
  <p>| Region | Vanilla TP | TP with SP |
 
527
  | Embedding Layer (Row Linear sharded on vocab) | h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: unchanged | h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded |</p>
530
+ <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into a reduce-scatter followed by an all-gather (see [TODO: Appendix link]), they’re actually equivalent in terms of communication volume. The same reasoning holds for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).</p>
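<p>You can convince yourself of this equivalence with a toy, single-process simulation in which the “ranks” are just list entries (no actual communication involved):</p>
<p>```python
import torch

world_size, numel = 4, 8
xs = [torch.randn(numel) for _ in range(world_size)]       # one tensor per "rank"

# All-reduce: every rank ends up with the full sum
allreduce_result = torch.stack(xs).sum(dim=0)

# Reduce-scatter: rank r only receives the sum of its own shard
shards = [x.chunk(world_size) for x in xs]
reduce_scattered = [sum(s[r] for s in shards) for r in range(world_size)]

# All-gather: every rank collects all the reduced shards -> same result as all-reduce
allgather_result = torch.cat(reduce_scattered)

torch.testing.assert_close(allreduce_result, allgather_result)
```</p>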
531
+ <p>You can find an example implementation of both column- and row-linear TP in picotron:
532
  <a href="https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py">https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py</a> </p>
533
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2031.png" /></p>
 
 
534
  <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops <strong>IN EACH LAYER</strong> (2 for Attention and 2 for MLP), as shown here for the MLP region:</p>
535
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2032.png" /></p>
536
  <p>Besides the fact that TP requires communication in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
537
  <blockquote>
538
  <p>Notice how all-gather is overlapped with “Y A1” that’s thanks to this trick
 
544
  <ul>
545
  <li>TP</li>
546
  </ul>
547
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2033.png" /></p>
548
  <ul>
549
  <li>Seq Parall</li>
550
  </ul>
551
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2034.png" /></p>
552
  <p>Allreduce takes almost double the duration (900us) of reducescatter and allgather (500us)</p>
553
  <p>Let’s compare throughput as we scale TP and TP/SP for a 3B model:</p>
554
+ <p><img alt="Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2035.png" /></p>
555
  <p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
556
  <p>Let’s summarize our observations:</p>
557
  <ul>
 
569
  <h1>Context Parallelism</h1>
570
  <p>With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g. when scaling to 128k or more tokens per sequence) we might still exceed the memory available on a single node, because inside the TP region we still have to process a full sequence length.</p>
571
  <p>Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:</p>
572
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2036.png" /></p>
573
  <p>Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.</p>
574
  <h2>Introducing Context Parallelism</h2>
575
  <p>The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.</p>
576
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2037.png" /></p>
577
  <p>Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.</p>
578
  <p>There is one important exception though, which is the <strong><em>attention module</em></strong>. In this module each token needs to access key/value pairs from <strong>all</strong> other sequence tokens, or in the case of causal attention at least attend to each previous token.</p>
579
  <p>Because Context Parallelism splits the inputs along the sequence dimension across GPUs, the attention module requires full communication between GPUs to exchange the necessary key/value data.</p>
 
595
  <p><img alt="ring-attention.gif" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/ring-attention.gif" /></p>
596
  <p>With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.</p>
597
  <p>There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs, stemming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:</p>
598
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2038.png" /></p>
599
  <p>The SoftMax is computed row-wise, which means a GPU can compute a row as soon as it has received all the tokens of that row. We see that GPU1 can compute its rows immediately, as it starts with tokens 1-4 and doesn’t need to receive any information from any other GPU. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all the values for tokens 1-8. GPU1 also seems to perform much less work than all the other GPUs.</p>
600
  <p>Let’s see if we can balance our computations better:</p>
601
  <h2>Zig-Zag Ring Attention – A Balanced Compute Implementation</h2>
602
  <p>We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential manner, but by mixing the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called <a href="https://arxiv.org/pdf/2311.09431">Zig-Zag attention</a>, and in this new arrangement, if you count the number of colored squares in the attention mask, you’ll see that the computation is now evenly balanced across all GPUs.</p>
603
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2039.png" /></p>
604
  <p>At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.</p>
605
  <p>We have two general ways to overlap computation and communication: either we perform a general all-gather, regrouping all the KVs on each GPU at the same time (in a ZeRO-3 type of way), or we gather them one-by-one from each GPU as needed:</p>
606
+ <p><img alt="Context Parallelism using AllGather implementation" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2040.png" /></p>
607
  <p>Context Parallelism using AllGather implementation</p>
608
+ <p><img alt="Context Parallelism using All-to-All implementation" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2041.png" /></p>
609
  <p>Context Parallelism using All-to-All implementation</p>
610
  <p>TODO: add links to megatronlm(AllGather) and deepspeed(All2All) implementations</p>
611
612
  <h1>Pipeline Parallelism</h1>
613
  <p>In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly for, e.g., the all-reduce operation when we perform it across several nodes:</p>
614
+ <p><img alt="Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2042.png" /></p>
615
  <p>Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.</p>
616
  <p>Sequence and context parallelism can help for long sequences, but they don’t help much if the root cause of our memory issues is not the sequence length but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.</p>
617
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2043.png" /></p>
618
  <p>Pipeline Parallelism is conceptually very simple – we’ll simply spread the layers of our model across GPUs – but the devil lies in implementing it efficiently. Let’s dive into it!</p>
619
  <h2>Splitting layers on various nodes - All forward, all backward</h2>
620
  <p>So, let’s say we simply spread the layers on several devices, e.g. a first GPU will take the first few layers and a second GPU will take the second part of the models and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model and thus successively using each compute device.</p>
621
  <p>We have a direct first advantage: the required interconnect bandwidth stays quite low as we only send moderate-sized activations at a handful of locations along the model depth. This is a huge difference compared to, e.g., the communication in Tensor Parallelism, which happens several times within each layer.</p>
622
  <p>But maybe you start feeling a glimpse of the troubles to come: “sequentially” and “successively”?!? This doesn’t sound very efficient in the world of parallel computation, especially after our discussion about computation and communication overlap.</p>
623
  <p>Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:</p>
624
+ <p><img alt="An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2044.png" /></p>
625
  <p>An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.</p>
626
  <p>The remaining idle time is indicated in grey and usually called the “bubble”, and the sight of it probably breaks your heart after we spent so much time optimizing throughput.</p>
627
  <p>We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say $t_f$ and $t_b$ are the times for the forward and backward pass, respectively, as measured for one microbatch and one stage of the pipeline (a simple assumption is often $t_b \approx 2 \times t_f$, which you can see on the above graph). If we could perfectly parallelize, the ideal total time would be $t_{id}=t_f + t_b$. However, we can count on the graph that due to the pipeline bubble there is additional time of $t_{pb}=(p-1)*(t_f+t_b)$ (where $p$ is the degree of pipeline parallelism, i.e. the number of GPUs on the above graph), i.e. the time each GPU is waiting while other GPUs are computing.</p>
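<p>To put some (made-up) numbers on this: with $p=4$, $t_f=1$ and $t_b=2$, the ideal time per micro-batch is $t_{id}=3$ while the bubble adds $t_{pb}=(4-1)\times(1+2)=9$, i.e. in this naive schedule each GPU spends three times longer waiting than computing.</p>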
 
632
  <p>As we add more stages the bubble time thus increases and the utilization drops.</p>
633
  <p>Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble which as you can see on this naive example can be very large in a naive implementation.</p>
634
  <p>Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions which can be processed in parallel, or almost, like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:</p>
635
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2045.png" /></p>
636
  <blockquote>
637
  <p>Note: previously the numbers indicated layers, but in all pipeline parallelism plots from now on (including this one) they indicate microbatches. You can think of each square here as containing several layers, as seen in the previous figure.
638
  </p>
 
680
  <p>Since the memory explosion is triggered by the activations we store for the backward pass, let’s see if we can start performing the backward pass while we are still performing other forward parts of the computation. This will allow us to drop some of the activations needed for the backward pass as soon as possible.</p>
681
  <h2>One-forward-one-backward and LLama 3.1 schemes</h2>
682
  <p>This schedule is called <strong>one-forward-one-backward (1F1B)</strong> as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:</p>
683
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2046.png" /></p>
684
  <p>The bubble still has the same size so our training efficiency is not significantly improved. However, we only need to store activations for $p$ micro-batches instead of $m$, which significantly reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches, which will then actually reduce the bubble.</p>
685
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2047.png" /></p>
686
  <p>A major complexity of this setup, visible in the above graph, is that forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.</p>
687
  <p>This is one of the reasons implementing Pipeline Parallelism usually requires rather extensive modifications to training code as well as modeling code.</p>
688
  <p>Here is the example training loop from the above gist:</p>
 
759
  <p>Well it turns out this is possible if we are willing to bring in a few additional communications. Time to talk about “<strong>Interleaved Stages</strong>”.</p>
760
  <p>Up to now we’ve sliced our model naively along the model depth dimension, locating for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could think about slicing our layers, e.g. having odd layers 1, 3, 5, 7 on the first GPU and even layers 2, 4, 6, 8 on the second GPU.</p>
761
  <p>This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass of the model.</p>
762
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2048.png" /></p>
763
  <p>As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously took just one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes. </p>
764
  <p>$$
765
  t_{pb} = \frac{(p-1)(t_f+t_b)}{v} \\
766
  r_{bubble} = \frac{1}{v}\frac{(p-1)(t_f+t_b)}{m(t_f+t_b)} = \frac{p-1}{v \cdot m}
767
  $$</p>
768
  <p>So we can now decrease the bubble by adding microbatches and interleaved stages, but note that, quantitatively, the amount of communication also increases by a factor of $v$, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with $p=8$, where the special case of $m=1, v=1$ corresponds to naive pipeline parallelism, the configurations with $v=1$ are AFAB or 1F1B setups, and $v \neq 1$ are interleaved configurations.</p>
769
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2049.png" /></p>
770
  <p>Scheduling also becomes more complex here, as we need to decide on a given GPU and at a given moment whether we prioritize earlier micro-batches, i.e. closing the forward and backward loops as fast as possible (so-called “depth-first”, prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all microbatches in the queue before moving on to backward passes (so-called “breadth-first”, prioritizing filling in the pipeline as much as possible). This is explained in detail in <a href="https://arxiv.org/pdf/2211.05953">https://arxiv.org/abs/2211.05953</a>.</p>
771
  <p>You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tunable between depth-first and breadth-first.</p>
772
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2050.png" /></p>
773
  <p>However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!</p>
774
  <h2>Zero Bubble and DualPipe</h2>
775
  <p>There are even more sophisticated ways to reduce the bubble and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way. For instance the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero-bubble regime.</p>
776
  <p>Let’s very quickly see how this can work by briefly detailing the <a href="https://arxiv.org/abs/2401.10241">ZeroBubble</a> work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):</p>
777
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2051.png" /></p>
778
  <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2052.png" /></p>
 
779
  <p>While the output of B, the backward pass for the inputs, is necessary for performing the backward pass of the lower layers, the backward pass for the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.</p>
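  <p>To make the decomposition concrete, here is a minimal sketch for a single linear layer $Y = XW$ (a toy example of ours, not the ZeroBubble implementation): the input gradient (B) must be produced right away for the previous stage, while the weight gradient (W) can be deferred to fill bubbles:</p>
  <p>```python
import torch

X = torch.randn(16, 512)          # activations saved from the forward pass
W = torch.randn(512, 512)         # layer weights
dL_dY = torch.randn(16, 512)      # gradient arriving from the next stage

# B: backward for the inputs -- must be computed immediately so it can be
# sent to the previous pipeline stage.
dL_dX = dL_dY @ W.T

# W: backward for the weights -- independent of the rest of the backward
# pass, so it can be scheduled later to fill pipeline bubbles.
def deferred_weight_grad():
    return X.T @ dL_dY

# ... other microbatches / stages can run here ...
dL_dW = deferred_weight_grad()    # only needed before the optimizer step
```</p>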
780
  <p>DeepSeek’s DualPipe proposes an extension of this decomposed approach with two streams propagating from both ends of the PP ranks, interleaved to minimize idle time on the GPUs even further, as displayed in the following scheduling graph:</p>
781
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png" /></p>
782
  <p>The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurement of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the <a href="https://arxiv.org/abs/2401.10241">ZeroBubble</a> paper for a discussion of the heuristics and algorithms used to perform such scheduling.</p>
783
  <h1>Expert parallelism</h1>
784
  <p>One more ~~thing~~ parallelism.</p>
785
  <p>Mixture-of-experts models have gained some traction with models such as Mixtral or, more recently, DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer we can have several and route tokens through different ones depending on their context.</p>
786
  <p>So whereas Context parallelism</p>
787
+ <p><img alt="https://arxiv.org/pdf/2407.06204" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2054.png" /></p>
788
  <p><a href="https://arxiv.org/pdf/2407.06204">https://arxiv.org/pdf/2407.06204</a></p>
789
  <p>This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent we can simply put each expert’s feedforward layer on a different worker. Compared to TP it’s much more lightweight, since we don’t need to split the matrix multiplication, we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.</p>
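  <p>To give a feel for the routing part, here is a toy, single-process sketch of a top-1 router grouping tokens per expert (our own simplification; a real EP implementation would ship each group to the rank hosting that expert, typically via an all-to-all):</p>
  <p>```python
import torch

num_experts, hidden = 4, 64
tokens = torch.randn(10, hidden)               # flattened (batch*seq) token states
router = torch.nn.Linear(hidden, num_experts)  # toy router

# Pick one expert per token (top-1 routing for simplicity).
expert_ids = router(tokens).argmax(dim=-1)

# Group the hidden states by destination expert; in expert parallelism each
# group would be shipped to the worker that owns that expert's feedforward.
groups = {e: tokens[expert_ids == e] for e in range(num_experts)}
for e, h in groups.items():
    print(f"expert {e}: {h.shape[0]} tokens")
```</p>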
790
  <p>Congratulations reader, we’ve now covered all 4 current directions of parallelism:</p>
 
810
  <p>In contrast to Pipeline Parallelism, Tensor Parallelism is naturally complementary and interoperable with ZeRO-3. For instance, if one of a model’s submodules (e.g. a layer or layer block) is too large to fit on a GPU when rematerialized by ZeRO-3, there is no obvious choice other than to perform partial local operations and use Tensor Parallelism for this block/sub-model, combined with other dimensions of parallelism like ZeRO-3 or PP as we saw above.</p>
811
  <p>Combining ZeRO-3 and TP doesn’t raise any specific issues except how to organize the GPUs into groups for each parallelism dimension. As detailed above, TP will typically be kept for high-speed intra-node communications while ZeRO-3 can use parallelism groups spanning lower-speed inter-node communications, as the overlap with computation is easy to perform.</p>
812
  <h1>How to Find the Best Training Configuration</h1>
813
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2055.png" /></p>
814
  <p>We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. There remains a general question: which ones should we choose and which ones are best combined? We touched on this a little at the end of the last section, but in this section we will walk through the decision process step by step.</p>
815
  <p>| <strong>Method</strong> | <strong>Memory savings</strong> | <strong>Parallel/sharding dimension</strong> | <strong>Disadvantage</strong> |
816
  | --- | --- | --- | --- |
 
825
  <p>On the other hand by using PP or TP we can reduce the DP rank significantly, increase the local batch size and thus reduce the necessary global communication. In general, it is natural to combine 4D parallelism with at least ZeRO-1/2 to save optimizer and gradient memory and keep the number of PP stages or TP ranks under control.</p>
826
  <p>Overall, most training schedules past a certain model size will tend to combine several of these methods.</p>
827
  <p>Let’s try to synthesize the decision process into a relatively simple tree structure: </p>
828
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png" /></p>
829
  <p>To explain briefly, data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the BS/GPU at a big enough value to make good use of the GPU matmuls, ZeRO is an easy way to remove memory bottlenecks while staying close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use of more 4D parallelism. In this case, starting with tensor parallelism is the most direct way to reduce memory usage and is generally faster than pipeline parallelism within a single node (8 GPUs). However, in scenarios with long contexts, the primary memory usage will tend to shift from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of this more as a starting point of hyperparameters to run your own benchmarks. For instance, sometimes TP mixed with PP can be more efficient, even with TP&lt;8, and ZeRO-1/2 can make sense to mix in with 4D parallelism as well. </p>
830
  <p>This concludes our very deep dive into the distribution methods of 4D parallelism and ZeRO. However, besides scaling our model efficiently across GPUs there is another way to improve model throughput and memory management. </p>
831
  <p>Time to turn the lights off and activate CUDA mode! </p>
 
837
  <h3>A primer on GPU</h3>
838
  <p>Generally, GPUs have a very hierarchical organization. In this primer we’ll keep the discussion at the conceptual level necessary for the rest of our presentation.</p>
839
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">https://resources.nvidia.com/en-us-tensor-core</a> for details), each capable of handling multiple threads simultaneously.</p>
840
+ <p><img alt="Original figure from https://blog.codingconfessions.com/p/gpu-computing." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png" /></p>
841
  <p>Original figure from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a>.</p>
842
  <p>The memory side is also highly hierarchical, with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong>, shared by all SMs; finally, there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100 for instance) but also the slowest to access and query. </p>
843
+ <p><img alt="Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2058.png" /></p>
844
  <p>Original figure from <a href="https://www.youtube.com/watch?v=ZQKMZIP3Fzg">https://www.youtube.com/watch?v=ZQKMZIP3Fzg</a></p>
845
  <p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute and memory.</p>
846
  <p>A piece of code running on a core of the GPU is called a <strong>kernel</strong>. It can be written at a high-level in <strong>CUDA</strong> or <strong>Triton</strong> for instance, and is then compiled to Parallel Thread Execution, PTX, the low-level assembly used by NVIDIA GPUs.</p>
847
  <p>To run the kernel, you will also need a specific code part (called <strong>host code</strong>) which is executed on the <strong>CPU</strong>/host and will take care of preparing data allocations and loading data and code.</p>
848
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2059.png" /></p>
849
  <p>Figure 5: Host code for a CUDA kernel for adding two vectors from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a></p>
850
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2060.png" /></p>
851
  <p>Figure 6: Device code containing the definition of the vector addition kernel from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a></p>
852
  <p>Kernels are generally scheduled as follows:</p>
853
  <ul>
 
876
  def elu(x, alpha=1.0):
877
  return torch.where(x &lt; 0, alpha * (torch.exp(x) - 1), x)</code></p>
878
  <p>The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns) :</p>
879
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2061.png" /></p>
880
  <p>However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the triton kernel generated by <code>@torch.compile</code> . To do so, you simply need to set the environment variable <code>TORCH_LOGS</code> to “output_code” : </p>
881
  <p><code>bash
882
  export TORCH_LOGS="output_code"</code></p>
 
924
  <p>```</p>
925
  <p>Here, <code>tl.program_id(0)</code> provides a unique block ID that we use to determine which section of data that block will process. Using this block ID, <code>block_start</code> calculates the starting index for each block’s section, while <code>block_indices</code> specifies the range of indices within that section. A <code>valid_mask</code> ensures that only indices within <code>num_elements</code> are processed, safely loading the data with <code>tl.load</code>. The ELU function is then applied, modifying values based on whether they're negative, and the results are written back to memory with <code>tl.store</code>.</p>
926
  <p>When we benchmark the generated kernel using <code>triton.testing.Benchmark</code> we have the following performance : </p>
927
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2062.png" /></p>
928
  <p>This standalone kernel demonstrates superior performance at smaller sizes compared to <code>@torch.compile</code>, but this is likely just an artifact of the compilation time of <code>torch.compile</code>. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process. </p>
929
  <p>However, in Triton, sometimes, we cannot fully achieve the peak performance of the device due to limitations in handling shared memory and scheduling within streaming multiprocessors (SMs). Our access is restricted to blocks, allowing us only to manage the scheduling of blocks across SMs. To gain even more control, we will need to implement kernels in CUDA, where we have access to all the underlying components. </p>
930
  <p>In CUDA, there are various techniques that can be employed to make kernels more efficient; we will present just a few. These include optimizing memory access patterns to reduce latency, using shared memory to store frequently accessed data, and managing thread workloads to minimize idle times. In summary, the tools for writing code to execute instructions on the GPU are:</p>
 
954
  <p>}
955
  ```</p>
956
  <p>Here’s an excellent visualization of the kernel from this fantastic <a href="https://siboehm.com/articles/22/CUDA-MMM">blogpost</a> : </p>
957
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2063.png" /></p>
958
  <p>However, when profiling this kernel with a tool like <code>ncu</code>, we can see issues, including low memory throughput and uncoalesced memory accesses.</p>
959
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2064.png" /></p>
960
  <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2065.png" /></p>
 
961
  <p>The reason for this is that in this kernel, two threads in the same block with Thread IDs <code>(0, 0)</code> and <code>(1, 0)</code> (which will end up in the same warp) will both load from the same column of matrix <code>B</code> but different rows of matrix <code>A</code>. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with <code>i = 0</code>, thread <code>(0, 0)</code> will load $A_{0,0}$, and thread <code>(1, 0)</code> will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.</p>
962
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2066.png" /></p>
963
  <p>To improve our kernel we can change the way the coordinates x and y are calculated like the following : </p>
964
  <p>```cpp
965
  const int x = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
 
974
  ```</p>
975
  <p>Instead of using a 2D block, we switch to a 1D block and redefine how we determine the values of <code>x</code> and <code>y</code>. In this new method, threads within the same warp (which have close <code>threadIdx.x</code> values) will share the same <code>x</code> value but have different <code>y</code> values. This means that they will load the same row of matrix <code>A</code> but different columns of matrix <code>B</code>. As a result, memory accesses can be coalesced for a row-major matrix.</p>
976
  <p>When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and <strong>the GPU's memory throughput has increased by approximately 10 times</strong>.</p>
977
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png" /></p>
978
  <p>We also notice that the execution time of the kernel <strong>decreases by 10x</strong> !</p>
979
  <p>Let’s cover another technique you will often see mentioned in the literature: tiling.</p>
980
  <h3>Tiling</h3>
981
  <p>Tiling is a technique that leverages <em>shared memory</em> to optimize memory access patterns. As we mentioned above, the shared memory is a small, fast memory accessible by all threads within a block. It allows data to be reused by multiple threads, reducing the need to repeatedly load data from slower global memory.</p>
982
  <p>In matrix multiplication for example, each thread in a block may need elements from two matrices, say A and B. If each thread independently loads the row and column it needs from global memory, we end up with many redundant loads, as multiple threads in a block will access overlapping data. Instead, we can use tiling to load a block (or tile) of A and B into shared memory just once, allowing all threads in that block to reuse the same shared data.</p>
983
  <p>In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size <code>BLOCK_SIZE_M</code> by <code>BLOCK_SIZE_K</code>) and a tile of matrix B (of size <code>BLOCK_SIZE_K</code> by <code>BLOCK_SIZE_N</code>). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed. </p>
984
+ <p><img alt="From https://cnugteren.github.io/tutorial/pages/page4.html" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2068.png" /></p>
985
  <p>From <a href="https://cnugteren.github.io/tutorial/pages/page4.html">https://cnugteren.github.io/tutorial/pages/page4.html</a></p>
986
  <p>The important parts to understand the implementation are below (for simplicity we consider a square shaped tile) : </p>
987
  <p>```cpp</p>
 
1013
  <p>When benchmarking this kernel with ncu, we notice that the memory throughput increased to 410 GB/s, and the kernel execution time decreased by ~43%, achieving ~6.6 TFLOPs of performance.</p>
1014
  <h3>Thread Coarsening</h3>
1015
  <p>The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:</p>
1016
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2069.png" /></p>
1017
  <p>The meaning of the states can be found in the <a href="https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference">Profiling Guide</a>, specifically in the <strong>Warp Stall Reasons</strong> section. There we can read that : </p>
1018
  <blockquote>
1019
  <p>smsp__pcsamp_warps_issue_stalled_mio_throttle : Warp was stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure.
 
1028
  <h2>Flash Attention 1-3</h2>
1029
  <p>Flash attention is a technique pioneered by <a href="https://tridao.me">Tri Dao</a> that optimizes the attention computation by writing custom CUDA kernels to make it much faster <em>and</em> more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU and to avoid relying too much on the slowest one: the global memory of the GPU (confusingly called the High Bandwidth Memory, HBM 🫠). </p>
1030
  <p>A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
1031
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png" /></p>
1032
  <p>Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p>
1033
  <p>The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of $O$ directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.</p>
1034
+ <p><img alt="From the FLASH-ATTENTION paper (https://arxiv.org/pdf/2205.14135)" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2071.png" /></p>
1035
  <p>From the FLASH-ATTENTION paper (<a href="https://arxiv.org/pdf/2205.14135">https://arxiv.org/pdf/2205.14135</a>)</p>
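  <p>The core trick can be sketched in a few lines of plain PyTorch. The following is a minimal, non-optimized illustration of the online-softmax accumulation (our own toy code operating on full tensors, ignoring the scaling factor, masking and the actual SRAM management) that processes K/V in blocks without ever materializing the full S matrix:</p>
  <p>```python
import torch

def attention_by_blocks(q, k, v, block_size=128):
    # q: (seq_q, d), k/v: (seq_k, d). Running max m, normalizer l, output accumulator o.
    seq_q, d = q.shape
    o = torch.zeros(seq_q, v.shape[1])
    m = torch.full((seq_q, 1), float("-inf"))
    l = torch.zeros(seq_q, 1)
    for start in range(0, k.shape[0], block_size):
        kb, vb = k[start:start+block_size], v[start:start+block_size]
        s = q @ kb.T                                   # scores for this K/V block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                       # block softmax numerator
        correction = torch.exp(m - m_new)              # rescale previously accumulated stats
        l = l * correction + p.sum(dim=-1, keepdim=True)
        o = o * correction + p @ vb
        m = m_new
    return o / l

q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax(q @ k.T, dim=-1) @ v               # naive attention as reference
assert torch.allclose(attention_by_blocks(q, k, v), ref, atol=1e-4)
```</p>
  <p>The actual Flash Attention kernels perform essentially this accumulation, but tile Q as well, keep the tiles in shared memory and fuse everything into a single kernel.</p>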
1036
  <p>The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:</p>
1037
  <ul>
 
1071
  </p>
1072
  </blockquote>
1073
  <p>Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice more bits on either the mantissa or the exponent. For this reason there also exist two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:</p>
1074
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2072.png" /></p>
1075
  <p>We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range while e4m3 has an even smaller range.</p>
1076
  <p>How come some formats are able to maintain the range and others are not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:</p>
1077
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2073.png" /></p>
1078
  <p>We can see here that bfloat16 maintains the range of float32 (unlike float16) but does so at the cost of sacrificing more precision. In the case of float8 the situation is even more dire, as e4m3 can represent only 7 and e5m2 only 3 numbers in the interval [1, 2].</p>
1079
  <p>A common metric to measure a format’s resolution is epsilon: the gap to the first representable number after 1.00. We can see that for the float32 format $10^{-4}$ is an upper bound (it’s actually about $1.19 \times 10^{-7}$). For float16 it is about $10^{-3}$ and for bfloat16 it is 10x higher still. </p>
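  <p>You can query these numbers directly (a quick check of ours, assuming a recent PyTorch build that exposes the float8 dtypes):</p>
  <p>```python
import torch

# Epsilon (gap after 1.0) plus smallest/largest representable values per format.
for dtype in (torch.float32, torch.float16, torch.bfloat16,
              torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{str(dtype):22} eps={info.eps:.2e}  min={info.min:.2e}  max={info.max:.2e}")
```</p>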
1080
  <p>The idea of mixed precision training is to use some of these lower-precision formats while maintaining the performance of full-precision training. It turns out we <strong>can’t</strong> totally abandon float32 and will usually need to maintain some parts in full precision. </p>
 
1092
  <h3>FP8 pretraining</h3>
1093
  <p>Even if we perfectly overlap communication with computation, we always eventually run into the low level theoretical FLOPS limit of the hardware itself, i.e. the efficiency of each individual operation on our hardware. This is where numerical precision becomes crucial. For instance, on NVIDIA's H100 GPU, FP8 matrix multiplications (GEMM operations) achieve twice the theoretical FLOPS of bfloat16, making lower-precision training an attractive path for further optimization.</p>
1094
  <p>Recent research - including <a href="https://arxiv.org/abs/2310.18313">FP8-LM</a>, <a href="https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8">torchao</a>, and <a href="https://arxiv.org/abs/2412.19437">DeepSeek-V3</a> - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.</p>
1095
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2074.png" /></p>
1096
  <p>As <a href="https://arxiv.org/abs/2309.14322">[Wortsman et al.]</a> observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.</p>
1097
  <p>The first successful very-large-scale training with FP8 mixed precision was publicly reported with DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward passes. Similar to BF16 mixed-precision training, some aggregations and the master weights are kept in higher precision while the operations themselves are performed in FP8. </p>
1098
  <p><img alt="Screenshot 2025-02-09 at 22.20.28.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/Screenshot_2025-02-09_at_22.20.28.png" /></p>
 
1196
  </ul>
1197
  <p>Throughout this blogpost we’ll scale LLM training from one to hundreds of GPUs. This will require the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called <strong>collective operations</strong>. In this section we’ll do a small crash course of <em>Broadcast, AllReduce, Scatter</em> and co. but if you are already familiar with these patterns feel free to move directly to [SECTION I], otherwise let’s get ☕ #1 (or your neural stimulant of choice) and let’s dig in! </p>
1198
  <p>The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1). </p>
1199
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2075.png" /></p>
1200
  <p>Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with <code>root</code> that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.</p>
1201
  <h3>Broadcast</h3>
1202
  <p>A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:</p>
1203
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2076.png" /></p>
1204
  <p>Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with <code>dist.init_process_group</code>, which sets up the communication backend (we’ll talk about NCCL later), determines how many workers (aka nodes) exist and assigns a rank to each one (which we can get with <code>dist.get_rank</code>). Finally, it establishes a connection between the workers.</p>
1205
  <p>To showcase the <code>broadcast</code> operation, let's create a tensor with non-zero values on <code>rank=0</code> and tensors full of zeros on the other workers. We then distribute the <code>rank=0</code> tensor to all other ranks with <code>dist.broadcast(tensor, src=0)</code> :</p>
1206
  <p>```python
 
1232
  <p>Great, seems like it works as expected. Note that the rank messages can be printed out of order as we have no control over which print statement is executed first (we ordered them here for readability). Now let’s move on to the Reduce and AllReduce patterns! </p>
1233
  <h3>Reduce &amp; AllReduce</h3>
1234
  <p>Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function <code>f()</code> which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:</p>
1235
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2077.png" /></p>
1236
  <p>Of course there is no magic “free-flying” node that can perform this operation; generally each node does a partial computation, with the nodes organized in a ring or tree structure. Here is a simple example: say we need to compute the sum of the numbers on each node and our nodes are connected in a ring pattern. The first node sends its number to a neighbour, which adds its own number to the received one before forwarding it to the next neighbour. At the end of a round along the ring, the first node will receive the total sum.</p>
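  <p>Here is a tiny pure-Python simulation of that ring-based reduction, just to illustrate the idea (this is of course not how NCCL actually implements it):</p>
  <p>```python
# Simulate a ring Reduce: each node adds its value and forwards the partial sum.
values = [1, 2, 3, 4]          # one value per node, node 0 is the "root"
num_nodes = len(values)

partial = values[0]
for rank in range(1, num_nodes):
    # node `rank` receives `partial` from its neighbour and adds its own value
    partial = partial + values[rank]

# after one trip around the ring, the result arrives back at node 0
print(f"root node receives the total sum: {partial}")  # 10
```</p>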
1237
  <p>Here’s the code to run a simple Reduce operation summing the tensors, we specify the operation to use with <code>op=dist.ReduceOp.SUM</code> (you can find more information on the supported operations in the docs: <a href="https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp">https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp</a>):</p>
1238
  <p>```python
 
1303
  <p>Now let’s turn to our next distributed communication operation. In many real cases, each node individually performs many complex computations and we need to share the final results among nodes. Gather and AllGather are the operations we want to use in this case. Let’s take a look! </p>
1304
  <h3>Gather &amp; AllGather</h3>
1305
  <p>Gather and AllGather are quite similar to the Broadcast in that they allow distributing data among nodes without modification. The main difference to Broadcast is that there is not one value to share from one node to all the others; instead, each node has an individual chunk of data, and we want to either gather all the chunks on one node (in the case of Gather) or gather them on all nodes (in the case of AllGather). A picture being worth 1000 words, let’s take a look:</p>
1306
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2078.png" /></p>
1307
  <p>Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).</p>
1308
  <p>In the case of the gather operation we need to prepare a container object where the gathered tensors can be stored, in this example the <code>gather_list</code>:</p>
1309
  <p>```python
 
1359
  <h3>Scatter &amp; ReduceScatter</h3>
1360
  <p>As the name subtly suggests, the goal of the Scatter operation is to take data on one node and distribute slices of it to all other nodes. It’s thus different from the Broadcast operation, which copies data without slicing, and it’s logically the inverse of the Gather operation.</p>
1361
  <p>The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case, but instead of moving the full result to just one node, we distribute slices of it evenly across all nodes:</p>
1362
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2079.png" /></p>
1363
  <p>The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the <code>src</code>:</p>
1364
  <p>```python
1365
  def example_scatter():
 
1418
  <p>We have now seen the main building blocks of distributed operations, but before we see them in action let’s have a look at a special operation used for synchronization: the Barrier.</p>
1419
  <h3>Barrier</h3>
1420
  <p>The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Only then are they allowed to continue with further computations:</p>
1421
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2080.png" /></p>
1422
  <p>We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:</p>
1423
  <p>```python
1424
  def example_barrier():
 
1503
  <p>print(p.key_averages().table(sort_by="cuda_time_total", row_limit=8))</p>
1504
  <p>```</p>
1505
  <p>This would print aggregated profiling results sorted by the total CUDA time, and the output would be:</p>
1506
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2081.png" /></p>
1507
  <p>You can also try to inspect the trace as we previously mentioned on <code>chrome://tracing/</code> </p>
1508
  <blockquote>
1509
  <p>If you're new to this tool, you can navigate the trace by using the right and left arrow keys. Additionally, you can zoom in and out by holding the <strong>Alt</strong> key while scrolling left or right with your mouse.
1510
  </p>
1511
  </blockquote>
1512
  <p>After zooming in, you can observe the flow of operations when calling <code>layer_norm</code> in this trace:</p>
1513
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2082.png" /></p>
1514
  <p>The sequence begins in the CPU (the upper section) with <code>aten::layer_norm</code>, progressing to <code>aten::native_layer_norm</code>, and then transitioning to <code>cudaLaunchKernel</code>. From there, we move on to the GPU, where the <code>vectorized_layer_norm_kernel</code> kernel is called. </p>
1515
  <blockquote>
1516
  <p>Note that you can enable memory profiling by setting <code>profile_memory</code> to <code>True</code> in the profiler. However, this can lead to more complex traces.
 
1523
  <p><code>bash
1524
  ncu --set full -o output python layer_norm.py</code></p>
1525
  <p>and open the file <code>output.ncu-rep</code> with Nsight Compute, you will have a view that looks like this : </p>
1526
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2083.png" /></p>
1527
  <p>This gives clear warnings about compute and memory utilization, and suggestions on how to better balance compute and memory in the kernel and achieve maximal occupancy.</p>
1528
  <p><strong>CPP extension</strong></p>
1529
  <p>If the kernel you want to profile isn’t already integrated into PyTorch, you can use PyTorch's <code>cpp_extension</code> module to easily compile and run custom CUDA code. The process is straightforward: just create your CUDA kernel in a <code>.cu</code> file, and use the <code>load</code> function from the <code>cpp_extension</code> module to load it in Python.</p>
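  <p>A minimal sketch of what this could look like (the file name <code>my_kernel.cu</code> and the function name <code>my_kernel_forward</code> are placeholders for your own code):</p>
  <p>```python
import torch
from torch.utils.cpp_extension import load

# Compiles the .cu file on the fly and exposes its bound functions to Python.
# "my_kernel.cu" is a placeholder: it should contain the CUDA kernel plus a
# C++ binding that exposes a function called my_kernel_forward.
my_ext = load(name="my_ext", sources=["my_kernel.cu"], verbose=True)

x = torch.randn(1024, device="cuda")
y = my_ext.my_kernel_forward(x)   # call the custom kernel, then profile it with ncu
```</p>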
 
1583
  \frac{dL}{dX} = \frac{dL}{dY} \frac{dY}{dX} = \frac{dL}{dY} W</p>
1584
  <p>$$</p>
1585
  <p>The chain rule applies here since the loss (L) depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).</p>
1586
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2084.png" /></p>
1587
  <p>Likewise, we can use the chain rule to compute the gradient w.r.t. the weights:</p>
1588
  <p>$$
1589
  \frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = \frac{dL}{dY} X
1590
  $$</p>
1591
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2085.png" /></p>
1592
  <p>Here is a snippet of code to clarify all the concepts above:</p>
1593
  <p>```python
1594
  def column_linear_forward(X, local_W, group):
 
1659
  <pre><code>example_column_row_linear()
1660
  </code></pre>
1661
  <p>```</p>
1662
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2086.png" /></p>
1663
  <p><strong>TODO</strong> add these illustrations somewhere? I found them helpful:</p>
1664
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2087.png" /></p>
1665
  <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2088.png" /></p>
 
1666
  <h2>A3: ZeRO-R</h2>
1667
  <p>To further optimize memory usage in large-scale training, DeepSpeed ZeRO-R introduces several techniques aimed at reducing the memory footprint of activation values during forward and backward propagation. The key strategies include partitioned activation checkpointing, the use of constant-size buffers, and memory defragmentation.</p>
1668
  <h3>$P_a:$ Partitioned Activation Checkpointing</h3>
 
1775
  </code></pre>
1776
  <p>```</p>
1777
  <h3>Interconnect</h3>
1778
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2089.png" /></p>
1779
  <h2>How to profile your code</h2>
1780
  <p>The profiler is a tremendously useful tool and easy to use. It takes three steps to profile your program:</p>
1781
  <ol>
 
1807
  </code></pre>
1808
  <p>```</p>
1809
  <p>After running this code, you will find <code>*.trace.json</code> files under the <code>profiler_out_dir</code>. To visualize the results, the easiest way is to open Google Chrome, go to <code>chrome://tracing/</code>, and drag the file into it. This will allow you to view the profiling results. To get more details, we invite you to check out the amazing <a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html">tutorial</a> created by PyTorch.</p>
1810
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2090.png" /></p>
1811
  <h2>Formulas for the compute / comms balance</h2>
1812
  <p>```markdown</p>
1813
  <p>Estimates: (number of elements, need to multiply by bytes_per_element)
 
1909
  <p>for a single microbatch:
1910
  -&gt; t_comm / t_compute = seq * mbs * h * peak_flops / (num_layers_in_next_pp * 32 * seq * mbs * h^2) * peak_bw = peak_flops / (32 * h*num_layers_in_next_pp) * peak_bw</p>
1911
  <p>```</p>
1912
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2091.png" /></p>
1913
  <h2>Integrating Context Parallelism with TP/SP</h2>
1914
  <p>We’ve seen that both TP/SP and CP shard the activations along the sequence dimension, and both require communications in the Attention module. Wouldn’t that create issues? In fact, not at all!</p>
1915
  <p>In order to integrate CP with TP/SP we just have to:</p>
 
1919
  <li><strong>Replace standard attention with ring attention:</strong> Each TP rank relies on ring attention to compute the correct attention results during both the forward and backward passes. So all CP ranks within TP=0, for example, need to all-gather the full KV sequence and calculate attention, but we store only the KV of a sequence chunk to reduce activation memory thanks to CP.</li>
1920
  </ol>
1921
  <p><img alt="TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
1922
+ TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2092.png" /></p>
1923
  <p>TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
1924
  TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank</p>
1925
  <p>Context parallelism is naturally compatible with data parallelism which splits the input along the batch size dimension.</p>
1926
  <p>In fact, given an activation value of shape $[\text{batch\_size}, \text{sequence\_length}, \text{hidden\_dimension}]$, data parallelism, sequence/context parallelism, and tensor parallelism split it across the 1st, 2nd, and 3rd dimensions, respectively, and these are independent of each other.</p>
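  <p>A small sketch of what this means for an activation tensor (toy shapes of our choosing, with plain <code>chunk</code> calls standing in for the actual sharding logic):</p>
  <p>```python
import torch

bs, seq, hidden = 8, 4096, 1024
dp, cp, tp = 2, 4, 8
activation = torch.randn(bs, seq, hidden)

local = activation.chunk(dp, dim=0)[0]   # data parallelism shards the batch dim
local = local.chunk(cp, dim=1)[0]        # context/sequence parallelism shards the sequence dim
local = local.chunk(tp, dim=2)[0]        # tensor parallelism shards the hidden dim
print(local.shape)  # torch.Size([4, 1024, 128])
```</p>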
1927
  <h2>The nanotron FP8 recipe</h2>
1928
  <p>However, through extensive experimentation, we identified two effective training recipes that allowed us to <strong>fully pretrain a 1B LLaMA model in FP8</strong>, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?</p>
1929
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2093.png" /></p>
1930
  <p>A loss curve that perfectly matches mixed-precision bfloat16 (bfloat16 with FP32 master weights as the baseline). We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.</p>
1931
  <p>Here’s what worked:</p>
1932
  <ul>
 
1957
  <p>The general idea is always the same: if there are parts we will soon need to communicate between workers and that are independent of the current computation, we can run the communication in parallel with the computation, which is also called overlapping them. </p>
1958
  <p>Let’s take a moment to look better at this fundamental tool for distributed training and go over the example of Ring Attention using PyTorch Profiler. In its implementation, we can overlap the sending and receiving of key/value pairs with the computation of attention scores. What does this look like?</p>
1959
  <p><strong>Non-overlapping:</strong> If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete and total time is the sum of communication and computation.</p>
1960
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2094.png" /></p>
1961
  <p><strong>Overlapping:</strong> However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now we can see that the computation blocks are launched immediately, one after the other. In this case the total time is <em>only</em> the sum of the computations.</p>
1962
+ <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2095.png" /></p>
1963
  <p>Context parallelism has helped us going past the intra-node interconnect bottleneck, which blocked us from scaling TP across nodes. However, as you probably noted, it only helps reducing the memory constraints if the activation memory dominates the memory budget due to long sequences. What if we are not working on super long sequences and the model weights alone are too big for a single node?</p>
1964
  <p>Well, it turns out we have another, quite different, option called pipeline parallelism (PP), which it is now time to explore.</p>
1965
  <p>[TODO: comment from Nouamane on comms overlapping with DP 512]</p>
blog-export.md CHANGED
@@ -120,18 +120,6 @@ These items are stored as tensors which come in different *shapes* and *precisio
120
 
121
  So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.
122
 
123
- ![**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%205.png)
124
-
125
- **Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**
126
-
127
- ![In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%206.png)
128
-
129
- In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens.
130
-
131
- ![**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%205.png)
132
-
133
- **Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**
134
-
135
  ### Memory profiling a training step
136
 
137
  Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:
@@ -235,7 +223,7 @@ It’s time to explain our first technique – called ***activation recomputatio
235
 
236
  The general idea behind ***activation recomputation*** – also called ***gradient checkpointing*** or ***rematerialization*** – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically will only store activations at a few key points along the model architecture, discard the rest of activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing again a sub-part of the forward pass to trade off memory for compute. It generally looks like this:
237
 
238
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%207.png)
239
 
240
  There are several strategies to select key activations to store:
241
 
@@ -282,7 +270,7 @@ $$
282
 
283
  Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch!
284
 
285
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%208.png)
286
 
287
  **Gradient accumulation allows us to reduce the memory of activations, which grows linearly with batch size, by computing only partial micro-batches. There is a small overhead caused by the additional forward and backward passes.**
288
 
@@ -293,6 +281,22 @@ But if you’ve carefully followed, you probably noticed that the forward/backwa
293
 
294
  Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called ***data parallelism** which is just a parallel version of gradient accumulation*.
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  # Data Parallelism
297
 
298
  The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism.
@@ -432,11 +436,11 @@ While Data Parallelism is very efficient in scaling training, the naive replicat
432
 
433
  This approach is organized into three possible optimization stage of ZeRO:
434
 
435
- ZeRO-1: optimizer state partitioning ($P_{os}$)
436
 
437
- ZeRO-2: optimizer state + gradient partitioning ($P_{os + g}$)
438
 
439
- ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning ($P_{os + g + p}$)
440
 
441
  > Note: You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different microbatch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!
442
  >
@@ -677,41 +681,48 @@ These operations "f" and "f*" are called conjugate pairs because they complement
677
 
678
  For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
679
 
680
- Which is why we call “f” and “f*” conjugate (as explained in [https://arxiv.org/pdf/2205.05198](https://arxiv.org/pdf/2205.05198)). For sequence parallelism, in “g*”, we don’t want to all-reduce activations in the “SP” region as that would increase our peak memory usage.
681
-
682
- ![The MLP block in TP+SP: After matrix multiplications Z1*B1 and Z1*B2, we need to restore the hidden dimension. Instead of using all-reduce like in vanilla TP, we use reduce-scatter along the sequence dimension to maintain sharding. This keeps activations sharded in both TP and SP regions.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2031.png)
683
 
684
- The MLP block in TP+SP: After matrix multiplications Z1*B1 and Z1*B2, we need to restore the hidden dimension. Instead of using all-reduce like in vanilla TP, we use reduce-scatter along the sequence dimension to maintain sharding. This keeps activations sharded in both TP and SP regions.
685
 
686
- Let's look at the MLP block in detail:
 
 
687
 
688
- 1. After the matrix multiplications Z1*B1 and Z1*B2, we need to restore the hidden dimension for correctness (this is part of row-linear)
689
- 2. However, instead of using all-reduce like in vanilla TP, we use reduce-scatter along the sequence dimension. We can do this because the SP region don’t need the full sequence dimension.
690
- 3. To optimize memory usage, we overlap the computation of Z1*B1 with the reduce-scatter operation. This keeps our peak memory usage low - the activation shape never exceeds (seq*mbs*hidden_dim/tp) at any point.
691
 
692
- This approach maintains sharding in both TP and SP regions while ensuring mathematical correctness.
 
 
 
693
 
694
- Instead, after the “Z1 B1” and “Z1 B2” matmuls we restore the hidden_dim (see row-linear section above), we need to reduce in (g*) to ensure correctness, because that’s part of the row-linear, but instead of all-reduce we scatter along sequence dimension, since SP region is independent along this dimension, so that we keep activations sharded in both TP and SP regions.
 
 
 
 
 
 
695
 
696
- Notice that at the moment where we restore the hidden_dim for “W1” and “W2”, we already have full activations with shape (seq*mbs*hidden_dim), so in practice we overlap the compute of “Z1 B1” with the *reduce-scatter* to always keep a max activation shape of seq*mbs*hidden_dim/tp
697
-
698
- Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).
699
 
700
It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we find it hard to keep track of as well, so we made this small table to summarize how the activation (aka `hidden_states`) shapes change across the hidden dimension h and the sequence dimension s during a forward pass:
701
 
702
  | Region | Vanilla TP | TP with SP |
703
  | --- | --- | --- |
704
- | Enter TP (Column Linear) | hidden dimension: sharded (weight_out is sharded)
705
- sequence dimension: unchanged | h: sharded (weight_out is sharded)
706
  s: **all-gather** to full |
707
  | TP Region | h: sharded
708
- s: unchanged | h: sharded
709
  s: full |
710
  | Exit TP (Row Linear) | h: full (weight_out is full + **all-reduce** for correctness)
711
- s: unchanged | h: full (weight_out is full + **reduce-scatter** for correctness)
712
  s: **reduce-scatter** to sharded |
713
  | SP Region | h: full
714
- s: unchanged | h: full
715
  s: sharded |
716
 
717
  And for the embedding layer
@@ -722,18 +733,16 @@ And for the embedding layer
722
  s: unchanged | h: full (weight_out is full + **reduce-scatter** for correctness)
723
  s: **reduce-scatter** to sharded |
724
 
725
- You can find an example of implementation of both column and row linear TP in pictotron:
726
- [https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py](https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py)
727
-
728
- TODO: everything after here is still a bit a mess (comment by leandro)
729
 
730
- Experimentally, TP+SP reduces the memory requirements while keeping a minimal compute overhead as we can see here, typically reducing by 30-50% activation memories
 
731
 
732
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2032.png)
733
 
734
  If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops **IN EACH LAYER** (2 for Attention and 2 for MLP), as shown here for the MLP region:
735
 
736
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2033.png)
737
 
738
Besides the fact that TP requires communication in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).
739
 
@@ -746,17 +755,17 @@ TODO: remove, Profiling:
746
 
747
  - TP
748
 
749
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2034.png)
750
 
751
  - Seq Parall
752
 
753
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2035.png)
754
 
755
  Allreduce takes almost double the duration (900us) of reducescatter and allgather (500us)
756
 
757
  Let’s compare throughput as we scale TP and TP/SP for a 3B model:
758
 
759
- ![Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2036.png)
760
 
761
  Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.
762
 
@@ -782,7 +791,7 @@ With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requi
782
 
783
  Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:
784
 
785
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2037.png)
786
 
787
Can we apply similar ideas to our sequence parallelism approach but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.
788
 
@@ -790,7 +799,7 @@ Can we apply similar ideas to our sequence parallelism approach but inside in th
790
 
791
  The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.
792
 
793
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2038.png)
794
 
795
  Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.
796
 
@@ -823,7 +832,7 @@ With this animation, it’s also immediately clear why the authors chose to call
823
 
824
There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs stemming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:
825
 
826
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2039.png)
827
 
828
The SoftMax is computed row-wise, which means a GPU can compute a row as soon as it has received all the tokens of that row. We see that GPU1 can start immediately, as it holds tokens 1-4 and doesn’t need to receive any information from the other GPUs. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all the values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.
829
 
@@ -833,17 +842,17 @@ Let’s see if we can balance our computations better:
833
 
834
We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens purely sequentially to the GPUs, but by mixing the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called [Zig-Zag attention](https://arxiv.org/pdf/2311.09431), and in this new arrangement, if you count the number of colored squares in the attention mask, you’ll see that the computation is now balanced across all GPUs.
835
 
836
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2040.png)
837
 
838
  At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
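To make the assignment concrete, here is a small sketch (our own illustration, not the reference implementation of the paper) of which token chunks each GPU ends up owning when we split a sequence in this zig-zag fashion:

```python
# Zig-zag layout: cut the sequence into 2*cp_size chunks and give rank i both chunk i
# and chunk (2*cp_size - 1 - i), pairing an "early" chunk with a "late" one so that
# the causal-attention work is balanced across GPUs.
def zigzag_chunks(seq_len: int, cp_size: int) -> list[list[range]]:
    chunk = seq_len // (2 * cp_size)
    chunks = [range(i * chunk, (i + 1) * chunk) for i in range(2 * cp_size)]
    return [[chunks[i], chunks[2 * cp_size - 1 - i]] for i in range(cp_size)]

for rank, owned in enumerate(zigzag_chunks(seq_len=16, cp_size=4)):
    print(f"GPU{rank}: tokens {[list(r) for r in owned]}")
# GPU0 gets tokens 0-1 and 14-15, GPU3 gets tokens 6-7 and 8-9, etc.
```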
839
 
840
We have two general ways to overlap computation and communication: either by performing a general all-gather, regrouping all the KV on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them one-by-one from each GPU to each GPU as needed:
841
 
842
- ![Context Parallelism using AllGather implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2041.png)
843
 
844
  Context Parallelism using AllGather implementation
845
 
846
- ![Context Parallelism using All-to-All implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2042.png)
847
 
848
  Context Parallelism using All-to-All implementation
849
 
@@ -855,13 +864,13 @@ TODO: add links to megatronlm(AllGather) and deepspeed(All2All) implementations
855
 
856
In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly, for example, on the all-reduce operation when we perform it across several nodes:
857
 
858
- ![Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2043.png)
859
 
860
  Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.
861
 
862
Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.
863
 
864
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2044.png)
865
 
866
Pipeline Parallelism is conceptually very simple (we’ll simply spread the layers of our model across GPUs) but the devil lies in implementing it efficiently. Let’s dive into it!
867
 
@@ -875,7 +884,7 @@ But maybe you start feeling a glimpse of the troubles to come: “sequentially
875
 
876
Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:
877
 
878
- ![An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2045.png)
879
 
880
  An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.
881
 
@@ -895,7 +904,7 @@ Thankfully, various pipeline parallelism schemes have been designed to reduce th
895
 
896
Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions which can be processed in parallel (or almost), like we did before with data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:
897
 
898
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2046.png)
899
 
900
> Note: previously the numbers indicated layers, but in all pipeline parallel plots from now on (including this one) they indicate micro-batches. You can think of each square here as containing several layers, as seen in the previous figure.
901
  >
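Schematically, the order of operations on every pipeline rank in this all-forward-all-backward (AFAB) schedule is simply all the forward passes followed by all the backward passes, as in this tiny sketch (a toy schedule generator, not an actual pipeline engine):

```python
# AFAB order of operations per pipeline rank: all forwards, then all backwards.
def afab_schedule(num_microbatches: int) -> list[str]:
    ops = [f"F{mb}" for mb in range(num_microbatches)]   # all forward passes...
    ops += [f"B{mb}" for mb in range(num_microbatches)]  # ...then all backward passes
    return ops

p, m = 4, 8  # 4 pipeline ranks, 8 micro-batches
for rank in range(p):
    print(f"rank {rank}: {' '.join(afab_schedule(m))}")
# Every rank has to keep the activations of all m micro-batches alive until its
# backward passes start, which is exactly the memory issue addressed next.
```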
@@ -954,11 +963,11 @@ Since the memory explosion is triggered by the activation we store for the backw
954
 
955
This schedule is called **one-forward-one-backward (1F1B)** as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
956
 
957
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2047.png)
958
 
959
The bubble still has the same size so our training efficiency is not significantly improved. However, we only need to store activations for $p$ micro-batches instead of $m$, which considerably reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more micro-batches, which will then actually reduce the bubble.
960
 
961
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2048.png)
962
 
963
A major complexity of this setup, visible in the above graph, is that forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device, instead of in a simple and common central training loop as usual.
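As a rough sketch of the resulting per-device order of operations (a toy schedule generator under the usual 1F1B description, not a production implementation): each rank runs `p - rank - 1` warmup forwards, then alternates one forward and one backward, and finishes with the remaining backwards.

```python
# 1F1B order of operations for one pipeline rank.
def one_f_one_b_schedule(rank: int, p: int, m: int) -> list[str]:
    warmup = min(p - rank - 1, m)
    ops = [f"F{i}" for i in range(warmup)]                 # warmup forwards
    f, b = warmup, 0
    while f < m:                                           # steady state: one forward, one backward
        ops += [f"F{f}", f"B{b}"]
        f, b = f + 1, b + 1
    ops += [f"B{i}" for i in range(b, m)]                  # cooldown backwards
    return ops

for rank in range(4):
    print(f"rank {rank}: {' '.join(one_f_one_b_schedule(rank, p=4, m=8))}")
# At most p micro-batches worth of activations are alive on a rank at any time,
# instead of m as in the AFAB schedule.
```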
964
 
@@ -1047,7 +1056,7 @@ Up to now we’ve sliced our model naively along the model depth dimensions, loc
1047
 
1048
This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass of the model.
1049
 
1050
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2049.png)
1051
 
1052
As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes.
1053
 
@@ -1058,13 +1067,13 @@ $$
1058
 
1059
So we can now decrease the bubble by adding micro-batches and interleaved stages, but note that quantitatively, the amount of communication also increases by a factor of $v$, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with $p=8$, where the special case of $m=1, v=1$ corresponds to naive pipeline parallelism, the configurations with $v=1$ are AFAB or 1F1B setups, and $v \neq 1$ are interleaved configurations.
1060
 
1061
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2050.png)
1062
 
1063
Scheduling also becomes more complex here, as we need to decide on each GPU whether to prioritize earlier micro-batches, closing the forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or to first complete the forward passes of all micro-batches in the queue before moving on to backward passes (so-called “breadth-first”, i.e. prioritizing filling the pipeline as much as possible). This is explained in detail in [https://arxiv.org/abs/2211.05953](https://arxiv.org/pdf/2211.05953).
1064
 
1065
You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and breadth-first.
1066
 
1067
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2051.png)
1068
 
1069
However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!
1070
 
@@ -1074,17 +1083,17 @@ There are even more sophisticated ways to reduce the bubble more and reached clo
1074
 
1075
Let’s very quickly see how this can work by briefly detailing the [ZeroBubble](https://arxiv.org/abs/2401.10241) work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):
1076
 
1077
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2052.png)
1078
 
1079
 
1080
 
1081
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png)
1082
 
1083
While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.
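Here is a toy numerical sketch of this decomposition for a single linear layer (our own illustration of the idea, not the ZeroBubble code): the input gradient "B" can be sent downstream immediately, while the weight gradient "W" can be computed later and checked against autograd.

```python
import torch

# Toy decomposition of the backward pass of Y = X @ W into "B" and "W".
torch.manual_seed(0)
X = torch.randn(16, 32)          # activations received from the previous stage
W = torch.randn(32, 64, requires_grad=True)
Y = X @ W
grad_Y = torch.randn_like(Y)     # gradient coming from the next stage

grad_X = grad_Y @ W.T            # "B": needed right away by the previous pipeline stage
# ... other micro-batches could be processed here ...
grad_W = X.T @ grad_Y            # "W": only needed before the optimizer step, so it can be deferred

# sanity check against autograd
Y.backward(grad_Y)
print(torch.allclose(grad_W, W.grad))  # True
```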
1084
 
1085
DeepSeek’s DualPipe proposes an extension of this decomposed approach to the case of two streams propagating from both sides of the PP ranks, interleaved to minimize idle time in the GPUs even further, as displayed in the following scheduling graph:
1086
 
1087
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2054.png)
1088
 
1089
The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurement of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the [ZeroBubble](https://arxiv.org/abs/2401.10241) paper for a discussion of the heuristics and algorithms used to perform such scheduling.
1090
 
@@ -1096,7 +1105,7 @@ Mixture-of-expert models have gained some traction with models such as Mixtral o
1096
 
1097
  So whereas Context parallelism
1098
 
1099
- ![[https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2055.png)
1100
 
1101
  [https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)
1102
 
@@ -1136,7 +1145,7 @@ Combining ZeRO-3 and TP doesn’t raise any specific issues except how to organi
1136
 
1137
  # How to Find the Best Training Configuration
1138
 
1139
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png)
1140
 
1141
We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. There remains a general question: which ones should we choose and which ones are best combined? We touched on this a little at the end of the last section, but in this section we will walk through the decision process step by step.
1142
 
@@ -1158,7 +1167,7 @@ Overall, most training schedule past a certain size of the models wil tend to co
1158
 
1159
Let’s try to synthesize the decision process into a relatively simple tree structure:
1160
 
1161
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png)
1162
 
1163
To explain briefly, data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the BS/GPU at a big enough value to make good use of the GPU MatMuls, ZeRO is an easy method to remove memory bottlenecks and stay close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use of the full 4D parallelism. In this case, starting with tensor parallelism is the most direct way to reduce memory usage and is generally faster than pipeline parallelism within a single node (8 GPUs). However, in scenarios with long contexts, the primary memory usage will tend to shift from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of this more as a starting point of hyperparameters to run your own benchmarks. For instance, sometimes TP mixed with PP can be more efficient, even if TP<8, and ZeRO-1/2 can make sense to mix in with 4D parallelism as well.
1164
 
@@ -1182,13 +1191,13 @@ Generally, GPUs have a very hierarchical organization. In this primer we’ll ke
1182
 
1183
  On the compute side, GPUs consist of an array of compute units called **Streaming Multiprocessors** (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see [https://resources.nvidia.com/en-us-tensor-core](https://resources.nvidia.com/en-us-tensor-core) for details), each capable of handling multiple threads simultaneously.
1184
 
1185
- ![Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2058.png)
1186
 
1187
  Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).
1188
 
1189
The memory side is also highly hierarchical with several layers of cache and memory: **Registers** are the smallest units and are private to the threads during execution, **Shared Memory** and the **L1 cache** are shared between the threads running on a single SM, higher up is the **L2 cache** shared by all SMs, and finally there is the **Global Memory** which is the largest memory on the GPU (the advertised 80 GB for an H100 for instance) but also the slowest to access and query.
1190
 
1191
- ![Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2059.png)
1192
 
1193
  Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)
1194
 
@@ -1198,11 +1207,11 @@ A piece of code running on a core of the GPU is called a **kernel**. It can be w
1198
 
1199
  To run the kernel, you will also need a specific code part (called **host code**) which is executed on the **CPU**/host and will take care of preparing data allocations and loading data and code.
1200
 
1201
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2060.png)
1202
 
1203
  Figure 5: Host code for a CUDA kernel for adding two vectors from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1204
 
1205
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2061.png)
1206
 
1207
  Figure 6: Device code containing the definition of the vector addition kernel from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1208
 
@@ -1241,7 +1250,7 @@ def elu(x, alpha=1.0):
1241
 
1242
  The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns) :
1243
 
1244
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2062.png)
1245
 
1246
  However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the triton kernel generated by `@torch.compile` . To do so, you simply need to set the environment variable `TORCH_LOGS` to “output_code” :
1247
 
@@ -1300,7 +1309,7 @@ Here, `tl.program_id(0)` provides a unique block ID, that we use to determine wh
1300
 
1301
  When we benchmark the generated kernel using `triton.testing.Benchmark` we have the following performance :
1302
 
1303
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2063.png)
1304
 
1305
This standalone kernel demonstrates superior performance at smaller sizes compared to `@torch.compile`, but this is likely just an artifact of the compilation time of torch.compile. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process.
1306
 
@@ -1340,17 +1349,17 @@ __global__ void matmul_naive(int M, int N, int K, const float *A, const float *B
1340
 
1341
  Here’s an excellent visualization of the kernel from this fantastic [blogpost](https://siboehm.com/articles/22/CUDA-MMM) :
1342
 
1343
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2064.png)
1344
 
1345
  However, when profiling this kernel with a tool like `ncu`, we can see issues, including low memory throughput and uncoalesced memory accesses.
1346
 
1347
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2065.png)
1348
 
1349
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2066.png)
1350
 
1351
  The reason for this is that in this kernel, two threads in the same block with Thread IDs `(0, 0)` and `(1, 0)` (which will end up in the same warp) will both load from the same column of matrix `B` but different rows of matrix `A`. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with `i = 0`, thread `(0, 0)` will load $A_{0,0}$, and thread `(1, 0)` will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.
1352
 
1353
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png)
1354
 
1355
  To improve our kernel we can change the way the coordinates x and y are calculated like the following :
1356
 
@@ -1371,7 +1380,7 @@ Instead of using a 2D block, we switch to a 1D block and redefine how we determi
1371
 
1372
  When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and **the GPU's memory throughput has increased by approximately 10 times**.
1373
 
1374
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2068.png)
1375
 
1376
  We also notice that the execution time of the kernel **decreases by 10x** !
1377
 
@@ -1385,7 +1394,7 @@ In matrix multiplication for example, each thread in a block may need elements f
1385
 
1386
  In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size `BLOCK_SIZE_M` by `BLOCK_SIZE_K`) and a tile of matrix B (of size `BLOCK_SIZE_K` by `BLOCK_SIZE_N`). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
1387
 
1388
- ![From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2069.png)
1389
 
1390
  From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)
1391
 
@@ -1429,7 +1438,7 @@ When benchmarking this kernel using ncu, we noticed that the memory throughput i
1429
 
1430
  The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:
1431
 
1432
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png)
1433
 
1434
  The meaning of the states can be found in the [Profiling Guide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference), specifically in the **Warp Stall Reasons** section. There we can read that :
1435
 
@@ -1454,13 +1463,13 @@ Flash attention is a technique pioneered by [Tri Dao](https://tridao.me) that op
1454
 
1455
A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:
1456
 
1457
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2071.png)
1458
 
1459
  Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!
1460
 
1461
The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of $O$ directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory but we also remove the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.
1462
 
1463
- ![From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2072.png)
1464
 
1465
  From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))
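To get a feel for the core trick, here is a small numerical sketch of the online softmax statistics maintained per query (a plain PyTorch illustration of the math, not the fused kernel; the sizes are arbitrary):

```python
import torch

# Online softmax: stream over key/value blocks keeping only a running max `m`,
# a running normalizer `l` and a running output `o`, so the full score matrix
# never has to be materialized.
torch.manual_seed(0)
d, n_kv, block = 64, 512, 128
q = torch.randn(d)
K = torch.randn(n_kv, d)
V = torch.randn(n_kv, d)

m = torch.tensor(float("-inf"))   # running max of the scores
l = torch.tensor(0.0)             # running sum of exp(score - m)
o = torch.zeros(d)                # running (unnormalized) output

for start in range(0, n_kv, block):
    k_blk, v_blk = K[start:start + block], V[start:start + block]
    s = k_blk @ q / d**0.5                    # scores for this block only
    m_new = torch.maximum(m, s.max())
    scale = torch.exp(m - m_new)              # rescale the previous statistics
    p = torch.exp(s - m_new)
    l = l * scale + p.sum()
    o = o * scale + p @ v_blk
    m = m_new

out_streaming = o / l
out_reference = torch.softmax(K @ q / d**0.5, dim=0) @ V
print(torch.allclose(out_streaming, out_reference, atol=1e-5))  # True
```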
1466
 
@@ -1518,13 +1527,13 @@ The principle of floating point numbers can be easily illustrated by recalling t
1518
 
1519
Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice bits in either the mantissa or the exponent. For this reason there also exist two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:
1520
 
1521
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2073.png)
1522
 
1523
We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range while e4m3 has an even smaller range.
1524
 
1525
How come some formats are able to maintain the range and others are not? Let’s investigate their resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:
1526
 
1527
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2074.png)
1528
 
1529
We can see here that bfloat16 maintained the range of float32 over float16, but did this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire, as e4m3 can represent only 7 and e5m2 only 3 numbers on the interval 1-2.
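You can reproduce this count yourself with a few lines of PyTorch, assuming a recent version that ships the `torch.float8_e4m3fn` and `torch.float8_e5m2` dtypes:

```python
import torch

# Count how many representable numbers each format has strictly between 1 and 2.
grid = torch.linspace(1, 2, 10_000)

for dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2):
    rounded = grid.to(dtype).float()                     # round to the format, then back to fp32
    inside = rounded[(rounded > 1) & (rounded < 2)]      # keep points strictly between 1 and 2
    print(f"{str(dtype):>22}: {inside.unique().numel()} representable values in (1, 2)")
# e4m3 gives 7 and e5m2 only 3, matching the plot above.
```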
1530
 
@@ -1554,7 +1563,7 @@ Even if we perfectly overlap communication with computation, we always eventuall
1554
 
1555
  Recent research - including [FP8-LM](https://arxiv.org/abs/2310.18313), [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8), and [DeepSeek-V3](https://arxiv.org/abs/2412.19437) - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
1556
 
1557
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2075.png)
1558
 
1559
  As [[Wortsman et al.]](https://arxiv.org/abs/2309.14322) observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.
1560
 
@@ -1692,7 +1701,7 @@ Throughout this blogpost we’ll scale LLM training from one to hundreds of GPUs
1692
 
1693
  The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1).
1694
 
1695
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2076.png)
1696
 
1697
  Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with `root` that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.
1698
 
@@ -1700,7 +1709,7 @@ Maybe we need to send the result from one node to all other nodes, or we need to
1700
 
1701
  A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:
1702
 
1703
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2077.png)
1704
 
1705
Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with `dist.init_process_group`, which sets up the communication backend (we’ll talk about NCCL later), determines how many workers (aka nodes) exist, and assigns a rank to each one (which we can get with `dist.get_rank`). Finally, it establishes a connection between the workers.
1706
 
@@ -1745,7 +1754,7 @@ Great, seems like it works as expected. Note that the rank messages can be print
1745
 
1746
  Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function `f()` which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:
1747
 
1748
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2078.png)
1749
 
1750
Of course there is no magic “free flying” node that can perform this operation; generally, each node does a partial computation, with the nodes organized in a ring or tree structure. Here is a simple example: let’s say we need to compute the sum of numbers residing on each node and our nodes are connected in a ring pattern. The first node sends its number to a neighbour, which adds its own number to the received one before forwarding it to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.
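Here is a tiny single-process simulation of that ring (plain Python, no actual communication), just to make the mechanism explicit:

```python
# Simulate a ring reduction: each node adds its own value to the received message
# and forwards it; after one trip around the ring, node 0 holds the total sum.
values = {0: 4, 1: 7, 2: 1, 3: 5}   # one number per node
num_nodes = len(values)

message = values[0]                  # node 0 starts the ring
for node in range(1, num_nodes):
    message += values[node]          # each node adds its contribution and forwards it
print(f"node 0 receives the total sum: {message}")  # 17
```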
1751
 
@@ -1840,7 +1849,7 @@ Now let’s turn to our next distributed communication operation. In many real c
1840
 
1841
Gather and AllGather are quite similar to the Broadcast in that they allow distributing data among nodes without modification. The main difference to Broadcast is that there is not one value we need to share from one node to all other nodes; instead, each node has an individual chunk of data, and we want to either gather all the data on one node (in the case of Gather) or gather all the data on all nodes (in the case of AllGather). A picture being worth a thousand words, let’s take a look:
1842
 
1843
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2079.png)
1844
 
1845
  Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).
1846
 
@@ -1914,7 +1923,7 @@ As the name subtly suggests, the goal of the Scatter operation is to take data o
1914
 
1915
  The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:
1916
 
1917
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2080.png)
1918
 
1919
  The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the `src`:
1920
 
@@ -1989,7 +1998,7 @@ We now have seen the main building block of distributed operations but before we
1989
 
1990
The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Only then are they allowed to continue with further computations:
1991
 
1992
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2081.png)
1993
 
1994
  We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:
1995
 
@@ -2102,7 +2111,7 @@ print(p.key_averages().table(sort_by="cuda_time_total", row_limit=8))
2102
 
2103
  This would print aggregated profiling results sorted by the total CUDA time, and the output would be:
2104
 
2105
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2082.png)
2106
 
2107
  You can also try to inspect the trace as we previously mentioned on `chrome://tracing/`
2108
 
@@ -2111,7 +2120,7 @@ You can also try to inspect the trace as we previously mentioned on `chrome://t
2111
 
2112
  After zooming in, you can observe the flow of operations when calling `layer_norm` in this trace:
2113
 
2114
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2083.png)
2115
 
2116
  The sequence begins in the CPU (the upper section) with `aten::layer_norm`, progressing to `aten::native_layer_norm`, and then transitioning to `cudaLaunchKernel`. From there, we move on to the GPU, where the `vectorized_layer_norm_kernel` kernel is called.
2117
 
@@ -2132,7 +2141,7 @@ ncu --set full -o output python layer_norm.py
2132
 
2133
  and open the file `output.ncu-rep` with Nsight Compute, you will have a view that looks like this :
2134
 
2135
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2084.png)
2136
 
2137
This gives clear warnings about compute and memory utilization, and suggestions on how to better balance compute and memory in the kernel and achieve maximal occupancy.
2138
 
@@ -2214,7 +2223,7 @@ $$
2214
 
2215
  The chain rule applies here since the loss (L) depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).
2216
 
2217
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2085.png)
2218
 
2219
Likewise, we can use the chain rule to compute the gradient w.r.t. the weight:
2220
 
@@ -2222,7 +2231,7 @@ $$
2222
  \frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = \frac{dL}{dY} X
2223
  $$
2224
 
2225
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2086.png)
2226
 
2227
  Here is a snippet of code to clarify all the concepts above:
2228
 
@@ -2301,13 +2310,13 @@ if __name__ == "__main__":
2301
  example_column_row_linear()
2302
  ```
2303
 
2304
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2087.png)
2305
 
2306
  **TODO** add these illustrations somewhere? I found them helpful:
2307
 
2308
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2088.png)
2309
 
2310
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2089.png)
2311
 
2312
  ## A3: ZeRO-R
2313
 
@@ -2452,7 +2461,7 @@ def example_gelu():
2452
 
2453
  ### Interconnect
2454
 
2455
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2090.png)
2456
 
2457
  ## How to profile your code
2458
 
@@ -2488,7 +2497,7 @@ with profiler: # step 2. Wrap the training with profiler
2488
 
2489
After running this code, you will find `*.trace.json` files under the `profiler_out_dir`. To visualize the results, the easiest way is to open Google Chrome, go to `chrome://tracing/`, and drag the file into it. This will allow you to view the profiling results. To get more details, we invite you to check out the amazing [**tutorial**](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html) created by PyTorch.
2490
 
2491
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2091.png)
2492
 
2493
## Formulas for the compute / comms balance
2494
 
@@ -2601,7 +2610,7 @@ for a single microbatch:
2601
 
2602
  ```
2603
 
2604
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2092.png)
2605
 
2606
  ## Integrating Context Parallelism with TP/SP
2607
 
@@ -2614,7 +2623,7 @@ In order to integrate CP with TP/SP we just have to:
2614
3. **Replace standard attention with ring attention:** During the forward pass, each TP rank relies on ring attention to compute the correct attention results during both the forward and backward passes. So all CP ranks within TP=0, for example, need to all-gather the full KV sequence and calculate attention, but we store only the KV of a sequence chunk to reduce activation memory thanks to CP.
2615
 
2616
  ![TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2617
- TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2093.png)
2618
 
2619
  TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2620
  TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank
@@ -2627,7 +2636,7 @@ In fact, given an activation value of shape$[ \text{batch\_size}, \text{sequence
2627
 
2628
  However, through extensive experimentation, we identified two effective training recipes that allowed us to **fully pretrain a 1B LLaMA model in FP8**, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?
2629
 
2630
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2094.png)
2631
 
2632
  A loss curve that perfectly matches mixed-precision bfloat16 (bfloat16 with FP32 master weights as the baseline). We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.
2633
 
@@ -2665,11 +2674,11 @@ Let’s take a moment to look better at this fundamental tool for distributed tr
2665
 
2666
  **Non-overlapping:** If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete and total time is the sum of communication and computation.
2667
 
2668
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2095.png)
2669
 
2670
  **Overlapping:** However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now we can see that the computation (green block) is launched immediately, one after the other. In this case the total time is *only* the sum of computations.
2671
 
2672
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2096.png)
2673
 
2674
Context parallelism has helped us go past the inter-node interconnect bottleneck that blocked us from scaling TP across nodes. However, as you probably noted, it only helps reduce the memory constraints if activation memory dominates the memory budget due to long sequences. What if we are not working on super long sequences and the model weights alone are too big for a single node?
2675
 
 
120
 
121
So how can I quickly determine memory usage from these variables? One simple way is to do this empirically and just measure it.
122
 
123
  ### Memory profiling a training step
124
 
125
  Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:
 
223
 
224
The general idea behind ***activation recomputation*** (also called ***gradient checkpointing*** or ***rematerialization***) is to discard some activations during the forward pass to save memory and spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm, etc.) such that we can use them during the backward pass to compute gradients. When we use recomputation, we typically only store activations at a few key points along the model architecture, discard the rest and recompute them on the fly during the backward pass from the nearest saved activations, essentially performing a sub-part of the forward pass again to trade off memory for compute. It generally looks like this:
225
 
226
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%205.png)
227
 
228
  There are several strategies to select key activations to store:
229
 
 
270
 
271
  Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch!
272
 
273
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%206.png)
274
 
275
**Gradient accumulation allows us to reduce the activation memory, which grows linearly with batch size, by computing only partial micro-batches. There is a small overhead caused by the additional forward and backward passes.**
276
 
 
281
 
282
Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called ***data parallelism***, which is just a parallel version of gradient accumulation.
283
 
284
+ TODO: intro for this
285
+
286
+ ## torch.profiler
287
+
288
291
+
292
+ ![In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%208.png)
293
+
294
+ In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens.
295
+
296
+ ![**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%207.png)
297
+
298
+ **Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**
299
+
300
  # Data Parallelism
301
 
302
The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro-batches of data in parallel on each GPU, hence the name Data Parallelism.
 
436
 
437
This approach is organized into three possible optimization stages of ZeRO:
438
 
439
+ ZeRO-1: optimizer state partitioning
440
 
441
+ ZeRO-2: optimizer state + gradient partitioning
442
 
443
+ ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning
444
 
445
  > Note: You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different microbatch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded!
446
  >
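As a point of reference, here is a minimal sketch of how these stages map onto PyTorch's built-in FSDP wrapper (just one possible implementation of these ideas, not necessarily the one discussed here; the model is a toy placeholder and we assume a single-node `torchrun` launch with one process per GPU). `ShardingStrategy.SHARD_GRAD_OP` roughly corresponds to ZeRO-2 and `ShardingStrategy.FULL_SHARD` to ZeRO-3.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# One process per GPU, launched with torchrun --nproc_per_node=N.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).cuda()

# FULL_SHARD shards parameters + gradients + optimizer states (ZeRO-3 like);
# SHARD_GRAD_OP would keep full parameters and shard only grads/optimizer states (ZeRO-2 like).
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()

dist.destroy_process_group()
```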
 
681
 
682
  For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
683
 
684
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/c65c0745-6dda-4f5c-a7ae-0092e50cdc0f.png)
 
 
685
 
686
+ So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:
687
 
688
+ 1. **Initial LayerNorm (SP Region)**
689
+ - Input tensors X1* and X2* (b,s/2,h) enter LayerNorm, already split across sequence dimension
690
+ - Each GPU computes LayerNorm independently on its sequence chunk and gives Y1* and Y2*
691
 
692
+ 2. **First Transition (SP → TP)**
693
+ - "g" operation (all-gather) combines Y1* and Y2* back to full sequence length
694
+ - Restores Y (b,s,h) since column linear layer needs full hidden dimension h
695
 
696
+ 3. **First Linear Layer (TP Region)**
697
+ - A1 is a column-linear layer, so it splits Y along the hidden dimension
698
+ - GeLU is applied independently on each GPU
699
+ - Z1* is (b,s,h/2)
700
 
701
+ 4. **Second Linear Layer (TP Region)**
702
+ - B1 is a row-linear layer, so it restores the hidden dimension
703
+ - W1 is (b,s,h)
704
+
705
+ 5. **Final Transition (TP → SP)**
706
+ - "g*" operation (reduce-scatter) which reduces for previous row-linear correctness while scattering along sequence dimension
707
+ - W1* is (b,s/2,h)
708
 
709
+ A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to $\frac{b \cdot s \cdot h}{tp}$ since we always either split along the sequence or hidden dimensions.
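
As a forward-only sketch (a real implementation wraps these in autograd functions whose backward is the conjugate operation; the function names below are our own), the two transitions look roughly like this:

```python
import torch
import torch.distributed as dist

def sp_to_tp(x_sp: torch.Tensor, group=None) -> torch.Tensor:
    """The "g" transition: all-gather along the sequence dimension,
    (b, s/tp, h) -> (b, s, h)."""
    tp = dist.get_world_size(group)
    chunks = [torch.empty_like(x_sp) for _ in range(tp)]
    dist.all_gather(chunks, x_sp.contiguous(), group=group)
    return torch.cat(chunks, dim=1)

def tp_to_sp(x_tp: torch.Tensor, group=None) -> torch.Tensor:
    """The "g*" transition: reduce-scatter along the sequence dimension,
    (b, s, h) -> (b, s/tp, h). The sum is the reduction needed for the
    correctness of the preceding row-linear layer."""
    tp = dist.get_world_size(group)
    chunks = [c.contiguous() for c in x_tp.chunk(tp, dim=1)]
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks, group=group)
    return out
```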
 
 
710
 
711
It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we found it hard to map as well, so we made this small table to summarize how the shape of the activations (aka `hidden_states`) changes across the hidden dimension h and the sequence dimension s during a forward pass:
712
 
713
  | Region | Vanilla TP | TP with SP |
714
  | --- | --- | --- |
715
+ | Enter TP (Column Linear) | h: sharded (weight_out is sharded)
716
+ s: full | h: sharded (weight_out is sharded)
717
  s: **all-gather** to full |
718
  | TP Region | h: sharded
719
+ s: full | h: sharded
720
  s: full |
721
  | Exit TP (Row Linear) | h: full (weight_out is full + **all-reduce** for correctness)
722
+ s: full | h: full (weight_out is full + **reduce-scatter** for correctness)
723
  s: **reduce-scatter** to sharded |
724
  | SP Region | h: full
725
+ s: full | h: full
726
  s: sharded |
727
 
728
  And for the embedding layer
 
733
  s: unchanged | h: full (weight_out is full + **reduce-scatter** for correctness)
734
  s: **reduce-scatter** to sharded |
735
 
736
+ Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduce operations per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication cost. The same reasoning holds for the backward pass, as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).
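
A quick way to see this equivalence is the usual ring-based cost model (a rough model that ignores latency terms): for a message of $m$ bytes over $N$ ranks with bandwidth $bw$,

$$
T_{\text{all-reduce}} \approx 2\,\frac{N-1}{N}\,\frac{m}{bw} = \underbrace{\frac{N-1}{N}\,\frac{m}{bw}}_{\text{reduce-scatter}} + \underbrace{\frac{N-1}{N}\,\frac{m}{bw}}_{\text{all-gather}}
$$

so swapping one all-reduce for an all-gather plus a reduce-scatter moves the same number of bytes over the wire.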
 
 
 
737
 
738
+ You can find an example implementation of both column- and row-linear TP in picotron:
739
+ [https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py](https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py)
740
 
741
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2031.png)
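
To give a flavor of what such an implementation looks like, here is a stripped-down sketch of a column-linear layer (our own simplified code with made-up names, not the actual picotron classes): the forward pass applies an identity to the input (the "f" operation), while its backward all-reduces the input gradient across the TP group.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class _CopyToTPRegion(torch.autograd.Function):
    """The "f" operation: identity in forward, all-reduce of the gradient in backward."""
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output)
        return grad_output

class ColumnParallelLinear(nn.Module):
    """Each rank holds only a slice of the output dimension of the weight matrix."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        tp = dist.get_world_size()
        assert out_features % tp == 0
        self.weight = nn.Parameter(torch.randn(out_features // tp, in_features) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = _CopyToTPRegion.apply(x)     # identity forward / all-reduce backward
        return F.linear(x, self.weight)  # output is sharded along its last dimension
```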
742
 
743
  If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops **IN EACH LAYER** (2 for Attention and 2 for MLP), as shown here for the MLP region:
744
 
745
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2032.png)
746
 
747
  Besides the fact that TP requires communication in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8)
748
 
 
755
 
756
  - TP
757
 
758
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2033.png)
759
 
760
  - Seq Parall
761
 
762
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2034.png)
763
 
764
  Allreduce takes almost double the duration (900us) of reducescatter and allgather (500us)
765
 
766
  Let’s compare throughput as we scale TP and TP/SP for a 3B model:
767
 
768
+ ![Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2035.png)
769
 
770
  Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.
771
 
 
791
 
792
  Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:
793
 
794
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2036.png)
795
 
796
Can we apply similar ideas to our sequence parallelism approach but inside the modules where we apply Tensor Parallelism already, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.
797
 
 
799
 
800
  The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.
801
 
802
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2037.png)
803
 
804
  Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.
805
 
 
832
 
833
There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs coming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:
834
 
835
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2038.png)
836
 
837
The SoftMax is computed row-wise, which means whenever a GPU has received all the tokens of a row, it can compute that row. We see that GPU1 can do so immediately, as it starts with tokens 1-4 and doesn’t need to receive any information from any other GPU. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all the values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.
838
 
 
842
 
843
We need a better way to distribute the input sequences. This can be achieved by assigning the tokens to the GPUs not purely sequentially but by mixing the ordering a bit, such that we have a good mix of early and late tokens on each GPU. This approach is called [Zig-Zag attention](https://arxiv.org/pdf/2311.09431), and in this new arrangement the attention mask shows an even distribution of computation: if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.
844
 
845
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2039.png)
846
 
847
  At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
848
 
849
We have two general ways to overlap computation and communication: either by performing a general all-gather, regrouping all the KV on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them from each GPU one by one as needed:
850
 
851
+ ![Context Parallelism using AllGather implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2040.png)
852
 
853
  Context Parallelism using AllGather implementation
854
 
855
+ ![Context Parallelism using All-to-All implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2041.png)
856
 
857
  Context Parallelism using All-to-All implementation
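
To make the second, ring-style variant more concrete, here is a minimal sketch of the KV-passing loop (our own illustrative code: it accumulates partial attention naively and skips the causal masking and online-softmax rescaling that a real ring attention implementation performs):

```python
import torch
import torch.distributed as dist

def ring_attention_sketch(q, k, v):
    """Each rank holds Q/K/V for its own sequence chunk; the KV chunks
    rotate around the ring so every rank eventually sees every chunk."""
    cp, rank = dist.get_world_size(), dist.get_rank()
    send_to, recv_from = (rank + 1) % cp, (rank - 1) % cp
    out = torch.zeros_like(q)
    for _ in range(cp):
        # partial attention with the KV chunk currently held on this rank
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        # NOTE: naive accumulation! a correct implementation merges partial
        # results with running max / log-sum-exp statistics (flash-attention style)
        out = out + scores.softmax(dim=-1) @ v
        # rotate the KV chunk to the next rank while receiving a new one
        k_recv, v_recv = torch.empty_like(k), torch.empty_like(v)
        ops = [dist.P2POp(dist.isend, k, send_to),
               dist.P2POp(dist.isend, v, send_to),
               dist.P2POp(dist.irecv, k_recv, recv_from),
               dist.P2POp(dist.irecv, v_recv, recv_from)]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        k, v = k_recv, v_recv
    return out
```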
858
 
 
864
 
865
In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly with, e.g., the all-reduce operation when we perform it across several nodes:
866
 
867
+ ![Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2042.png)
868
 
869
  Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.
870
 
871
Sequence and context parallelism can help for long sequences but don’t help much if the sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.
872
 
873
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2043.png)
874
 
875
Pipeline Parallelism is conceptually very simple: we’ll simply spread the layers of our model across GPUs. But the devil lies in implementing it efficiently. Let’s dive into it!
876
 
 
884
 
885
Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:
886
 
887
+ ![An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2044.png)
888
 
889
  An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.
890
 
 
904
 
905
Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions which can be processed in parallel, or almost, like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:
906
 
907
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2045.png)
908
 
909
> Note: previously the numbers indicated model layers, but in all pipeline parallelism plots from now on (including this one) they indicate microbatches. You can think of each square here as containing several layers, as seen in the previous figure.
910
  >
 
963
 
964
This schedule is called **one-forward-one-backward** **(1F1B)** as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
965
 
966
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2046.png)
967
 
968
The bubble still has the same size so our training efficiency is not significantly improved. However, we only need to store activations for $p$ micro-batches instead of $m$, which significantly reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches, which will then actually reduce the bubble.
969
 
970
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2047.png)
971
 
972
A major complexity of this setup, visible in the above graph, is that forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device, instead of in a simple and common central training loop as usual.
973
 
 
1056
 
1057
  This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.
1058
 
1059
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2048.png)
1060
 
1061
As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously took just one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes.
1062
 
 
1067
 
1068
So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by $v$ so it’s a trade-off. In the following plot you can see several configurations for a PP setup with $p=8$, where the special case of $m=1, v=1$ corresponds to naive pipeline parallelism, the configurations with $v=1$ are AFAB or 1F1B setups, and $v \neq 1$ are interleaved configurations.
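
As a rough rule of thumb (the usual first-order estimate, ignoring communication overhead), the bubble fraction of an interleaved 1F1B schedule scales as

$$
r_{bubble} \approx \frac{p-1}{v \cdot m}
$$

so increasing $v$ shrinks the bubble just like increasing $m$ does, at the price of $v$ times more communication.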
1069
 
1070
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2049.png)
1071
 
1072
Scheduling also becomes more complex here, as we need to decide on each GPU whether, at a given moment, we prioritize earlier micro-batches, meaning that we close their forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we first complete the forward passes of all microbatches in the queue before going over to backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in [https://arxiv.org/abs/2211.05953](https://arxiv.org/pdf/2211.05953).
1073
 
1074
You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tunable between depth-first and breadth-first.
1075
 
1076
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2050.png)
1077
 
1078
However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!
1079
 
 
1083
 
1084
Let’s very quickly see how this can work by briefly detailing the [ZeroBubble](https://arxiv.org/abs/2401.10241) work, which is a precursor to DualPipe. The base observation of ZeroBubble is that a backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):
1085
 
1086
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2051.png)
1087
 
1088
 
1089
 
1090
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2052.png)
1091
 
1092
While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.
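
Here is a tiny illustration of this decomposition for a linear layer (our own sketch, not the ZeroBubble code): the input gradient is returned immediately, while the weight-gradient computation is packaged up so it can be run later, e.g. inside a bubble.

```python
import torch

def linear_backward_decomposed(x, weight, grad_output):
    """Backward of y = x @ weight.T, split into the B and W parts.
    x: (tokens, in), weight: (out, in), grad_output: (tokens, out)."""
    grad_input = grad_output @ weight            # B: needed right away downstream
    def deferred_weight_grad():                  # W: can be scheduled later
        return grad_output.t() @ x
    return grad_input, deferred_weight_grad

x, w = torch.randn(8, 16), torch.randn(32, 16)
g = torch.randn(8, 32)
grad_x, w_grad_fn = linear_backward_decomposed(x, w, g)
# ... run other forward/backward work to fill the bubble ...
grad_w = w_grad_fn()                             # executed whenever convenient
```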
1093
 
1094
DeepSeek’s DualPipe proposes an extension of this decomposed approach to the case of two streams propagating from both sides of the PP ranks, interleaved to minimize idle time in the GPUs even further, as displayed in the following scheduling graph:
1095
 
1096
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png)
1097
 
1098
The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurement of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the [ZeroBubble](https://arxiv.org/abs/2401.10241) paper for a discussion of the heuristics and algorithms to perform such a scheduling.
1099
 
 
1105
 
1106
  So whereas Context parallelism
1107
 
1108
+ ![[https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2054.png)
1109
 
1110
  [https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)
1111
 
 
1145
 
1146
  # How to Find the Best Training Configuration
1147
 
1148
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2055.png)
1149
 
1150
We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. There remains a general question: which ones should we choose and which ones are best combined? We touched a little bit on this at the end of the last section, but in this section we will walk through the decision process step by step.
1151
 
 
1167
 
1168
Let’s try to synthesize the decision process into a relatively simple tree structure:
1169
 
1170
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png)
1171
 
1172
To explain briefly, data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the BS/GPU at a big enough value to make good use of the GPU MatMuls, ZeRO is an easy method to remove memory bottlenecks and stay close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use of more of the 4D parallelism dimensions. In this case, starting with tensor parallelism is the most direct way to reduce memory usage and is generally faster than pipeline parallelism within a single node (8 GPUs). However, in scenarios with long contexts, the primary memory usage will tend to shift from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of it more as a starting point of hyperparameters for running your own benchmarks. For instance, sometimes TP mixed with PP can be more efficient even if TP<8, and ZeRO-1/2 can make sense to mix in with 4D parallelism as well.
1173
 
 
1191
 
1192
  On the compute side, GPUs consist of an array of compute units called **Streaming Multiprocessors** (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see [https://resources.nvidia.com/en-us-tensor-core](https://resources.nvidia.com/en-us-tensor-core) for details), each capable of handling multiple threads simultaneously.
1193
 
1194
+ ![Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png)
1195
 
1196
  Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).
1197
 
1198
The memory side is also highly hierarchical with several layers of cache and memory: **Registers** are the smallest units and are private to the threads during execution, **Shared Memory** and the **L1 cache** are shared between the threads running on a single SM, higher up is the **L2 cache** shared by all SMs, and finally there is the **Global Memory** which is the largest memory on the GPU (the advertised 80 GB for an H100 for instance) but also the slowest to access and query.
1199
 
1200
+ ![Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2058.png)
1201
 
1202
  Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)
1203
 
 
1207
 
1208
  To run the kernel, you will also need a specific code part (called **host code**) which is executed on the **CPU**/host and will take care of preparing data allocations and loading data and code.
1209
 
1210
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2059.png)
1211
 
1212
  Figure 5: Host code for a CUDA kernel for adding two vectors from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1213
 
1214
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2060.png)
1215
 
1216
  Figure 6: Device code containing the definition of the vector addition kernel from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1217
 
 
1250
 
1251
The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns):
1252
 
1253
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2061.png)
1254
 
1255
However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the triton kernel generated by `@torch.compile`. To do so, you simply need to set the environment variable `TORCH_LOGS` to “output_code”:
1256
 
 
1309
 
1310
When we benchmark the generated kernel using `triton.testing.Benchmark`, we have the following performance:
1311
 
1312
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2062.png)
1313
 
1314
This standalone kernel demonstrates superior performance with smaller sizes compared to `@torch.compile`, but this is likely just an artifact of the compilation time of `torch.compile`. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process.
1315
 
 
1349
 
1350
Here’s an excellent visualization of the kernel from this fantastic [blogpost](https://siboehm.com/articles/22/CUDA-MMM):
1351
 
1352
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2063.png)
1353
 
1354
  However, when profiling this kernel with a tool like `ncu`, we can see issues, including low memory throughput and uncoalesced memory accesses.
1355
 
1356
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2064.png)
1357
 
1358
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2065.png)
1359
 
1360
  The reason for this is that in this kernel, two threads in the same block with Thread IDs `(0, 0)` and `(1, 0)` (which will end up in the same warp) will both load from the same column of matrix `B` but different rows of matrix `A`. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with `i = 0`, thread `(0, 0)` will load $A_{0,0}$, and thread `(1, 0)` will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.
1361
 
1362
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2066.png)
1363
 
1364
To improve our kernel we can change the way the coordinates x and y are calculated, as follows:
1365
 
 
1380
 
1381
  When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and **the GPU's memory throughput has increased by approximately 10 times**.
1382
 
1383
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png)
1384
 
1385
We also notice that the execution time of the kernel **decreases by 10x**!
1386
 
 
1394
 
1395
  In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size `BLOCK_SIZE_M` by `BLOCK_SIZE_K`) and a tile of matrix B (of size `BLOCK_SIZE_K` by `BLOCK_SIZE_N`). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
1396
 
1397
+ ![From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2068.png)
1398
 
1399
  From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)
1400
 
 
1438
 
1439
  The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:
1440
 
1441
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2069.png)
1442
 
1443
The meaning of the states can be found in the [Profiling Guide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference), specifically in the **Warp Stall Reasons** section. There we can read that:
1444
 
 
1463
 
1464
A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:
1465
 
1466
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png)
1467
 
1468
  Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!
1469
 
1470
The key element is to compute the S matrices in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of $O$ directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also relieve the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.
1471
 
1472
+ ![From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2071.png)
1473
 
1474
  From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))
1475
 
 
1527
 
1528
Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice more bits on either the mantissa or the exponent. For this reason there also exist two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:
1529
 
1530
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2072.png)
1531
 
1532
We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range and e4m3 has an even smaller range.
1533
 
1534
How come some formats are able to maintain the range and others are not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:
1535
 
1536
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2073.png)
1537
 
1538
We can see here that bfloat16 maintained the range of float32 over float16, but did this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire, as e4m3 can represent only 7 and e5m2 only 3 numbers on the interval 1-2.
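
You can reproduce this count with a few lines of PyTorch (assuming a recent version that ships the `torch.float8_e4m3fn` and `torch.float8_e5m2` dtypes):

```python
import torch

points = torch.linspace(1, 2, 10_000)
for dtype in (torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2):
    rounded = points.to(dtype).to(torch.float32)
    # count distinct representable values strictly between 1 and 2
    inside = torch.unique(rounded[(rounded > 1) & (rounded < 2)])
    print(f"{str(dtype):>22}: {inside.numel()} values between 1 and 2")
```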
1539
 
 
1563
 
1564
  Recent research - including [FP8-LM](https://arxiv.org/abs/2310.18313), [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8), and [DeepSeek-V3](https://arxiv.org/abs/2412.19437) - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
1565
 
1566
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2074.png)
1567
 
1568
  As [[Wortsman et al.]](https://arxiv.org/abs/2309.14322) observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.
1569
 
 
1701
 
1702
  The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1).
1703
 
1704
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2075.png)
1705
 
1706
  Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with `root` that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.
1707
 
 
1709
 
1710
  A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:
1711
 
1712
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2076.png)
1713
 
1714
Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with `dist.init_process_group`, which sets up the communication backend (we’ll talk about NCCL later), determines how many workers (aka nodes) exist, and assigns a rank to each one (which we can get with `dist.get_rank`). Finally, it establishes a connection between the workers.
1715
 
 
1754
 
1755
  Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function `f()` which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:
1756
 
1757
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2077.png)
1758
 
1759
Of course there is no magic “free-flying” node that can perform this operation, and generally each node does a partial computation in a ring or tree structure of the nodes. Here is a simple example: let’s say we need to compute a sum of numbers on each node and our nodes are connected in a ring pattern. The first node sends its number to a neighbour, which adds its number to the received number before forwarding it to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.
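
Here is a tiny single-process simulation of that ring sum (pure Python, just to illustrate the data flow; real implementations such as ring all-reduce work on chunks and run the steps in parallel):

```python
# a tiny single-process simulation of the ring-sum described above
values = [3, 1, 4, 1, 5]            # one number per "node"; the nodes form a ring
n = len(values)

message = values[0]                  # node 0 sends its number to its neighbour
for node in range(1, n):
    message += values[node]          # each node adds its number and forwards it
print(f"after one trip around the ring, node 0 receives the sum: {message}")
```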
1760
 
 
1849
 
1850
Gather and AllGather are quite similar to the Broadcast in that they allow distributing data among nodes without modification. The main difference to Broadcast is that there is not one value we need to share from one node to all other nodes; instead, each node has an individual chunk of data and we want to either gather all the data on one node (in the case of Gather) or gather all the data on all nodes (in the case of AllGather). A picture being worth a thousand words, let’s take a look:
1851
 
1852
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2078.png)
1853
 
1854
  Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).
1855
 
 
1923
 
1924
  The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:
1925
 
1926
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2079.png)
1927
 
1928
  The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the `src`:
1929
 
 
1998
 
1999
The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Only then are they allowed to continue with further computations:
2000
 
2001
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2080.png)
2002
 
2003
  We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:
2004
 
 
2111
 
2112
  This would print aggregated profiling results sorted by the total CUDA time, and the output would be:
2113
 
2114
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2081.png)
2115
 
2116
  You can also try to inspect the trace as we previously mentioned on `chrome://tracing/`
2117
 
 
2120
 
2121
  After zooming in, you can observe the flow of operations when calling `layer_norm` in this trace:
2122
 
2123
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2082.png)
2124
 
2125
  The sequence begins in the CPU (the upper section) with `aten::layer_norm`, progressing to `aten::native_layer_norm`, and then transitioning to `cudaLaunchKernel`. From there, we move on to the GPU, where the `vectorized_layer_norm_kernel` kernel is called.
2126
 
 
2141
 
2142
  and open the file `output.ncu-rep` with Nsight Compute, you will have a view that looks like this :
2143
 
2144
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2083.png)
2145
 
2146
It gives clear warnings about compute and memory utilization, and about how to make the kernel better balance compute and memory and achieve maximal occupancy.
2147
 
 
2223
 
2224
  The chain rule applies here since the loss (L) depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).
2225
 
2226
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2084.png)
2227
 
2228
Likewise, we can use the chain rule to compute the gradient w.r.t. the weight:
2229
 
 
2231
  \frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = \frac{dL}{dY} X
2232
  $$
2233
 
2234
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2085.png)
2235
 
2236
  Here is a snippet of code to clarify all the concepts above:
2237
 
 
2310
  example_column_row_linear()
2311
  ```
2312
 
2313
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2086.png)
2314
 
2315
  **TODO** add these illustrations somewhere? I found them helpful:
2316
 
2317
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2087.png)
2318
 
2319
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2088.png)
2320
 
2321
  ## A3: ZeRO-R
2322
 
 
2461
 
2462
  ### Interconnect
2463
 
2464
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2089.png)
2465
 
2466
  ## How to profile your code
2467
 
 
2497
 
2498
After running this code, you will find `*.trace.json` files under the `profiler_out_dir`. To visualize the results, the easiest way is to open Google Chrome, go to `chrome://tracing/`, and drag the file into it. This will allow you to view the profiling results. To get more details, we invite you to check out the amazing [tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html) created by PyTorch.
2499
 
2500
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2090.png)
2501
 
2502
## Formulas for the compute / comms balance
2503
 
 
2610
 
2611
  ```
2612
 
2613
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2091.png)
2614
 
2615
  ## Integrating Context Parallelism with TP/SP
2616
 
 
2623
3. **Replace standard attention with ring attention:** During the forward pass, each TP rank relies on the ring attention to compute the correct attention results during both the forward and backward passes. So all CP ranks within TP=0, for example, need to all-gather the full KV sequence and calculate attention, but we store only the KV of a sequence chunk to reduce the activation memory by the CP factor.
2624
 
2625
  ![TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2626
+ TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2092.png)
2627
 
2628
  TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2629
TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU1 get QKV_green, and GPU2 and GPU3 get QKV_blue); since each head can operate independently from the others, we can apply ring attention within each TP rank
 
2636
 
2637
  However, through extensive experimentation, we identified two effective training recipes that allowed us to **fully pretrain a 1B LLaMA model in FP8**, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?
2638
 
2639
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2093.png)
2640
 
2641
  A loss curve that perfectly matches mixed-precision bfloat16 (bfloat16 with FP32 master weights as the baseline). We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.
2642
 
 
2674
 
2675
  **Non-overlapping:** If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete and total time is the sum of communication and computation.
2676
 
2677
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2094.png)
2678
 
2679
**Overlapping:** However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now we can see that the computations (purple blocks) are launched immediately, one after the other. In this case the total time is *only* the sum of computations.
2680
 
2681
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2095.png)
2682
 
2683
Context parallelism has helped us get past the intra-node interconnect bottleneck, which blocked us from scaling TP across nodes. However, as you probably noted, it only helps reduce the memory constraints if the activation memory dominates the memory budget due to long sequences. What if we are not working on super long sequences and the model weights alone are too big for a single node?
2684
 
dist/bibliography.bib CHANGED
@@ -367,4 +367,40 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
367
  archivePrefix={arXiv},
368
  primaryClass={cs.CL},
369
  url={https://arxiv.org/abs/2412.19437},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  }
 
367
  archivePrefix={arXiv},
368
  primaryClass={cs.CL},
369
  url={https://arxiv.org/abs/2412.19437},
370
+ }
371
+ @misc{mccandlish2018largebatchtraining,
372
+ title={An Empirical Model of Large-Batch Training},
373
+ author={Sam McCandlish and Jared Kaplan and Dario Amodei and OpenAI Dota Team},
374
+ year={2018},
375
+ eprint={1812.06162},
376
+ archivePrefix={arXiv},
377
+ primaryClass={cs.LG},
378
+ url={https://arxiv.org/abs/1812.06162},
379
+ }
380
+ @misc{micikevicius2018mixedprecisiontraining,
381
+ title={Mixed Precision Training},
382
+ author={Paulius Micikevicius and Sharan Narang and Jonah Alben and Gregory Diamos and Erich Elsen and David Garcia and Boris Ginsburg and Michael Houston and Oleksii Kuchaiev and Ganesh Venkatesh and Hao Wu},
383
+ year={2018},
384
+ eprint={1710.03740},
385
+ archivePrefix={arXiv},
386
+ primaryClass={cs.AI},
387
+ url={https://arxiv.org/abs/1710.03740},
388
+ }
389
+ @misc{rajbhandari2020zero,
390
+ title={ZeRO: Memory Optimizations Toward Training Trillion Parameter Models},
391
+ author={Samyam Rajbhandari and Jeff Rasley and Olatunji Ruwase and Yuxiong He},
392
+ year={2020},
393
+ eprint={1910.02054},
394
+ archivePrefix={arXiv},
395
+ primaryClass={cs.LG},
396
+ url={https://arxiv.org/abs/1910.02054},
397
+ }
398
+ @misc{korthikanti2022recomputation,
399
+ title={Reducing Activation Recomputation in Large Transformer Models},
400
+ author={Vijay Korthikanti and Jared Casper and Sangkug Lym and Lawrence McAfee and Michael Andersch and Mohammad Shoeybi and Bryan Catanzaro},
401
+ year={2022},
402
+ eprint={2205.05198},
403
+ archivePrefix={arXiv},
404
+ primaryClass={cs.LG},
405
+ url={https://arxiv.org/abs/2205.05198},
406
  }
dist/index.html CHANGED
@@ -184,6 +184,7 @@
184
<p>As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training, let’s take a quick high-level look at what we’ll cover in the post.</p>
185
 
186
  <h2>TL;DR</h2>
 
187
  <p>This book is very extensive so we decide to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size the fastest possible.</p>
188
  <p>When scaling up models and input batches, we quickly end up in situations where either our target batch size won't fit in memory, or/and the model itself is too large to fit in a single GPU's memory.</p>
189
  <p>To solve this scaling issue we’ll need to carefully evaluate different parallelization strategies and find the optimal balance between three main factors:</p>
@@ -213,18 +214,258 @@
213
 
214
  <h2>First Steps: Training on one GPU</h2>
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  <h3>Memory usage in Transformers</h3>
 
 
 
 
 
 
 
 
 
217
 
 
 
 
 
 
 
 
 
 
 
 
218
  <h4>Memory profiling a training step</h4>
219
-
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  <h4>Weights/grads/optimizer states memory</h4>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
 
 
222
  <h4>Activations memory</h4>
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  <h3>Activation recomputation</h3>
225
 
226
- <h3>Gradient accumulation</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  <h2>Data Parallelism</h2>
229
 
230
  <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
 
184
<p>As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training, let’s take a quick high-level look at what we’ll cover in the post.</p>
185
 
186
  <h2>TL;DR</h2>
187
+
188
  <p>This book is very extensive so we decide to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size the fastest possible.</p>
189
  <p>When scaling up models and input batches, we quickly end up in situations where either our target batch size won't fit in memory, or/and the model itself is too large to fit in a single GPU's memory.</p>
190
  <p>To solve this scaling issue we’ll need to carefully evaluate different parallelization strategies and find the optimal balance between three main factors:</p>
 
214
 
215
  <h2>First Steps: Training on one GPU</h2>
216
 
217
+ <p>Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps: </p>
218
+
219
+ <ol>
220
+ <li>a forward pass which passes inputs through the model to yield its outputs,</li>
221
+ <li>a backward pass to compute the gradients, and</li>
222
+ <li>an optimization step using the gradients to update the parameters</li>
223
+ </ol>
224
+
225
+ <p>It looks generally like this: </p>
226
+ <p><img alt="image.png" src="assets/images/placeholder.png" /></p>
227
+
228
+ <aside>As we’ll see later, these steps may be repeated or intertwined but for now we’ll start simple.</aside>
229
+
230
+ <p>In this figure, the boxes on the top line can be seen as successive layers inside a model (same for the last line). The red boxes are the associated gradients for each of these layers, computed during the backward pass.</p>
231
+
232
+ <p>The batch size (<d-math>bs</d-math>) is one of the important hyper-parameters for model training and affects both model convergence and throughput.</p>
233
+
234
+ <p>If the batch size is too small, gradients will tend to be noisy and the model may not be able to converge to the most optimal performance; on the other hand, this noise can be useful in early training to navigate quickly through the training landscape. A batch size that is too large, on the contrary, will make less use of each training token, rendering convergence slower and potentially wasting compute. You can find a nice discussion of this topic in OpenAI’s paper on large batch training<d-cite bibtex-key="mccandlish2018largebatchtraining"></d-cite> or Section 4.2 of the MiniMax-01 <a href="https://filecdn.minimax.chat/_Arxiv_MiniMax_01_Report.pdf">technical report</a>.</p>
235
+
236
+ <aside>For instance, during DeepSeek-V3/R1 training “the batch size is gradually increased from 3072 to 15360 in the training of the first 469B tokens, and then keeps 15360 in the remaining training”.</aside>
237
+
238
+ <p>Batch size also affects the time it takes to train on a given text dataset: a small batch size will require more optimizer steps to train on the same number of samples. Optimizer steps are costly (in compute time) and the total time to train will thus increase compared to using a larger batch size. This being said, note that the batch size can often be adjusted quite widely around the optimal batch size without major impact on the performance of the model, i.e. the sensitivity of the final model performance to the exact batch size value is usually rather low around the optimal batch size.</p>
239
+
240
+ <p>In the LLM pretraining community, batch sizes are commonly reported in terms of tokens rather than in number of samples (<d-math>bst</d-math> = Batch Size Tokens); this makes training numbers generally independent of the exact input sequence length used during training.</p>
241
+
242
+ <p>In the simplest case, training on a single machine, the <d-math>bs</d-math> (in samples) and <d-math>bst</d-math> can be computed from the model input sequence length (seq) as follows :</p>
243
+
244
+ <aside><p>From here onward we’ll show the formulas for the batch size in terms of samples but you can always get its token-unit counterpart by multiplying it with the sequence length.</p>
245
+ </aside>
246
+
247
+ <d-math block>
248
+ bst = bs * seq
249
+ </d-math>
250
+
251
+ <p>A sweet spot for recent LLM training is typically on the order of 4-60 million tokens per batch. However, a common issue when scaling the training of our model to these large batch sizes is running out of memory, i.e. our GPU doesn’t have enough memory.</p>
252
+
253
+ <aside>Note: Llama 1 was trained with a batch size of ~4M tokens for 1.4 trillion tokens while DeepSeek was trained with a batch size of ~60M tokens for 14 trillion tokens.
254
+ </aside>
255
+
256
+ <p><strong>It’s time to tackle our first scaling problem: what if our model starts exploding GPU memory before we’ve reached our target batch size (maybe in some case even when using the lowest possible batch size, <code>bs=1</code>)?</strong></p>
257
+
258
+ <p>Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.</p>
259
+
260
  <h3>Memory usage in Transformers</h3>
261
+
262
+ <p>When training a neural network model, one stores several items in memory:</p>
263
+
264
+ <ul>
265
+ <li>Model weights</li>
266
+ <li>Activations needed to compute the gradients</li>
267
+ <li>Model gradients</li>
268
+ <li>Optimizer states</li>
269
+ </ul>
270
 
271
+ <aside>You would think that for a model you could compute the memory requirements exactly, but there are a few additional memory occupants that make it hard to be exact:
272
+ <ul>
273
+ <li>CUDA Kernels typically require 1-2 GB of GPU memory, which you can quickly verify by running <code>import torch; torch.ones((1, 1)).to("cuda")</code> and then checking the GPU memory with <code>nvidia-smi</code>.</li>
274
+ <li>Some residual memory usage from buffers and intermediate results, and some memory that can’t be used due to fragmentation</li>
275
+ </ul>
276
+ We’ll neglect these last two contributors as they are typically small and constant factors.</aside>
277
+
278
+ <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p>
279
+
280
+ <p>So how can we quickly determine memory usage from these variables? One simple way is to do this empirically and just measure it.</p>
281
+
282
  <h4>Memory profiling a training step</h4>
283
+
284
+ <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
285
+
286
+ <p><img alt="llama-1b-memory.png" src="assets/images/placeholder.png" /></p>
287
+
288
+ <p>Clearly the first step looks very different from the subsequent ones, but let’s first have a look at the general anatomy of a step: first the activations increase quickly as we do the forward pass, then during the backward pass the gradients build up and as the backward pass propagates, the stored activations used to compute the gradients are progressively cleared. Finally, we perform the optimization step during which we need all the gradients and then update the optimizer states before we start the next forward pass. </p>
289
+
290
+ <p>Why does the first step look different? The activations increase quickly and then plateau for a while. In this first step the torch cache allocator does a lot of preparation, setting up memory allocations to speed up the subsequent steps so that they don’t require searching for free memory blocks afterwards (see <a href="https://zdevito.github.io/2022/08/04/cuda-caching-allocator.html">Zach’s blog</a>). After the first step we also see the optimizer states appearing, which generally offsets the memory usage for further training steps.</p>
291
+
292
+ <aside>Ever noticed how sometimes the training succeeds in the first step but then OOMs during the following training steps? This can be explained by the build-up of the optimizer state after the first step.
293
+ </aside>
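
 <p>If you want to reproduce a rough version of such a profile on your own setup, a minimal starting point (a simplified sketch using a toy model; the snippet referenced above records a full timeline) is to bracket one training step with PyTorch’s CUDA memory counters:</p>

 <pre><code class="language-python">
import torch
import torch.nn as nn

model = nn.Linear(4096, 4096).cuda()
optimizer = torch.optim.AdamW(model.parameters())
batch = torch.randn(64, 4096, device="cuda")

torch.cuda.reset_peak_memory_stats()
loss = model(batch).pow(2).mean()   # forward pass (toy loss)
loss.backward()                     # backward pass
optimizer.step()                    # optimizer step
optimizer.zero_grad()
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
 </code></pre>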
294
+
295
+ <p>Now that we have a first view of memory, let’s see how scaling up training is often a question of maximizing compute efficiency while keeping the memory requirements of these various items (activations, parameters, gradients, optimizer states) within the memory constraints of the GPUs.</p>
296
+
297
  <h4>Weights/grads/optimizer states memory</h4>
298
+
299
+ <p>We can actually pretty easily estimate the memory needed for the model’s weights, gradients and optimizer states.</p>
300
+
301
+ <p>For a simple transformer LLM the number of parameters is given by the <a href="https://michaelwornow.net/2024/01/18/counting-params-in-transformer">following formula</a>:</p>
302
+
303
+ <d-math block>
304
+ N = h * v + L * (12 * h^2 + 13 * h) + 2*h
305
+ </d-math>
306
+
307
+ <aside>We excluded the positional embedding count as rotary embeddings are not learned.</aside>
308
+
309
+ <p>In that equation, <d-math>h</d-math> is the hidden dimension, <d-math>v</d-math> the vocabulary size, and <d-math>L</d-math> the number of layers in the model. Note that looking at the equation we can see that the term that will dominate at large hidden dimensions is the <d-math>h^2</d-math> term since it’s the only one growing quadratically as we scale the parameters.</p>
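+ <p>As a quick sanity check, here is that formula as a tiny helper (the hyper-parameters below are illustrative, roughly those of a 7B-class model):</p>
+ <pre><code class="language-python">
+ def transformer_params(h: int, v: int, L: int) -> int:
+     # N = h*v + L*(12*h^2 + 13*h) + 2*h
+     return h * v + L * (12 * h**2 + 13 * h) + 2 * h
+ 
+ # e.g. h=4096, v=32000, L=32 gives ~6.6B parameters, in the right ballpark for a 7B-class model
+ print(f"{transformer_params(4096, 32000, 32) / 1e9:.2f}B")
+ </code></pre>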
310
+
311
+ <p>Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training, both parameters and gradients require 4 bytes each, while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another two times 4 bytes per parameter. In summary:</p>
312
+
313
+ <d-math block>
314
+ \begin{aligned}
315
+ & m_{params} = 4 * N \\
316
+ & m_{grad} = 4 * N \\
317
+ & m_{opt} = (4+4) * N
318
+ \end{aligned}
319
+ </d-math>
320
+
321
+ <p>Now let’s have a look at how things change if we train with mixed precision<d-cite bibtex-key="micikevicius2018mixedprecisiontraining"></d-cite>. The default nowadays for mixed precision training is BF16, which requires 2 bytes per parameter and per gradient, as well as an additional copy of the model weights in FP32, thus 8 bytes per parameter so far (4 more if gradients are also kept in FP32, see the aside below). In addition to the parameters and gradients, we need to store the optimizer states: for the Adam optimizer, this requires the momentum and the variance, usually stored in FP32 for numerical stability, each using 4 bytes. </p>
322
+
323
+ <aside>See some more details below when we cover the ZeRO methods.</aside>
324
+
325
+ <p>Here’s the summary:</p>
326
+
327
+ <d-math block>
328
+ \begin{aligned}
329
+ & m_{params} = 2 * N \\
330
+ & m_{grad} = 2 * N \\
331
+ & m_{params\_fp32} = 4 * N \\
332
+ & m_{opt} = (4+4) * N
333
+ \end{aligned}
334
+ </d-math>
335
+
336
+ <aside>Some libraries store gradients in FP32, which would require an additional $m_{grad\_fp32} = 4 * N$ of memory. This is done for example in nanotron, because <code>bf16</code> is lossy for smaller values and we always prioritize stability. See <a href="https://github.com/microsoft/DeepSpeed/issues/1773">this DeepSpeed issue</a> for more information.
337
+ </aside>
338
+
339
+ <p>Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous, as running the forward/backward passes in half precision allows us to (1) use optimized lower precision operations on the GPU, which are faster, and (2) reduce the activation memory requirements during the forward pass.</p>
340
+
341
+ <p>Let’s get a sense of how much general memory we need for a model (full and mixed precision giving the same overall value):</p>
342
+
343
+ <table>
344
+ <thead>
345
+ <tr>
346
+ <th><strong>Model parameters</strong></th>
347
+ <th><strong>FP32 or BF16 w/o FP32 grad acc</strong></th>
348
+ <th><strong>BF16 w/ FP32 grad acc</strong></th>
349
+ </tr>
350
+ </thead>
351
+ <tbody>
352
+ <tr>
353
+ <td>1B</td>
354
+ <td>16 GB</td>
355
+ <td>20 GB</td>
356
+ </tr>
357
+ <tr>
358
+ <td>7B</td>
359
+ <td>112 GB</td>
360
+ <td>140 GB</td>
361
+ </tr>
362
+ <tr>
363
+ <td>70B</td>
364
+ <td>1120 GB</td>
365
+ <td>1400 GB</td>
366
+ </tr>
367
+ <tr>
368
+ <td>405B</td>
369
+ <td>6480 GB</td>
370
+ <td>8100 GB</td>
371
+ </tr>
372
+ </tbody>
373
+ </table>
374
+
375
+ <aside>Using FP8 training instead of BF16 would further decrease the memory usage but it is less stable and a very active research topic (see <a href="https://x.com/xariusrke/status/1826669126955278401">this tweet</a>) and we’ll cover it in more detail later.
376
+ </aside>
377
+
378
+ <p>As we can see, as soon as we reach <strong>7B</strong> (!), weights and optimizer requirements already start to add up significantly and exceed the size of a typical GPU memory, e.g. 80GB for an H100 GPU.</p>
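+ <p>These numbers follow directly from the byte counts above; a back-of-the-envelope helper (a sketch of the accounting, not a measurement) reproduces the table:</p>
+ <pre><code class="language-python">
+ def training_memory_gb(n_params: float, fp32_grad_acc: bool = False) -> float:
+     # BF16 params (2) + BF16 grads (2) + FP32 weights copy (4) + Adam states (4+4) = 16 bytes/param
+     # (full-precision FP32 training gives the same total: 4 + 4 + 8 = 16),
+     # plus 4 extra bytes/param if gradients are also accumulated in FP32
+     bytes_per_param = 16 + (4 if fp32_grad_acc else 0)
+     return n_params * bytes_per_param / 1e9
+ 
+ for n in (1e9, 7e9, 70e9, 405e9):
+     print(f"{n / 1e9:.0f}B: {training_memory_gb(n):.0f} GB / {training_memory_gb(n, True):.0f} GB")
+ </code></pre>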
379
 
380
+ <p>But for now, let’s start with models which still fit in a single GPU and take a look at the other big contributor to our memory budget: the activation memory.</p>
381
+
382
  <h4>Activations memory</h4>
383
 
384
+ <p>Activation memory is a bit more complex to compute than the weights, gradients and optimizer states, in part because it depends on the inputs of the model. If you’re unsure why we even need to store activations for the backward pass, <a href="https://www.determined.ai/blog/act-mem-2">this reference</a> is a good quick refresher. After a careful inspection of how the backward pass is computed, we can estimate the total memory required for the activations in mixed precision and arrive at the following equation:</p>
385
+
386
+ <d-math block>
387
+ m_{act} = L * seq * bs * h * (34 + \frac{5 * n_{heads} * seq}{h})
388
+ </d-math>
389
+
390
+ <p>Here <d-math>L</d-math> is the number of layers, <d-math>seq</d-math> the sequence length, <d-math>bs</d-math> the batch size in samples, <d-math>h</d-math> the hidden dimension of the model and <d-math>n_{heads}</d-math> the number of heads.</p>
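+ <p>Written as code, the formula looks like this (the layer count, hidden size and head count below are illustrative, roughly those of an 8B-scale Llama model):</p>
+ <pre><code class="language-python">
+ def activation_memory_gb(L: int, seq: int, bs: int, h: int, n_heads: int) -> float:
+     # m_act = L * seq * bs * h * (34 + 5 * n_heads * seq / h), in bytes (mixed-precision activations)
+     return L * seq * bs * h * (34 + 5 * n_heads * seq / h) / 1e9
+ 
+ # e.g. L=32, h=4096, n_heads=32 (roughly 8B-scale), bs=1
+ for seq in (2048, 4096, 8192, 16384):
+     print(f"seq={seq}: {activation_memory_gb(32, seq, 1, 4096, 32):.1f} GB")
+ </code></pre>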
391
+
392
+ <p>For the exact derivation of the numbers, you can follow the original NVIDIA paper on recomputation <d-cite bibtex-key="korthikanti2022recomputation"></d-cite>; it essentially requires you to do some accounting of all the sizes of intermediate activations between each operation in a transformer layer.</p>
393
+
394
+ <p>An interesting observation here is that the memory is not static for a given model but scales linearly with both the sequence length and the batch size. This means the activation memory is the part which will blow up when we increase our batch size or train with longer sequences. We can use this equation to look at how memory usage changes for various sequence lengths, for example for Llama models (<code>bs=1</code>):</p>
395
+
396
+ <p><img alt="llama-memory-bars-no-recomp.png" src="/assets/images/placeholder.png" /></p>
397
+
398
+ <p>This graph tells a striking story: for short sequences (and similarly for small batch sizes), activations are almost negligible, but starting at around 2-4k tokens they come to take a significant amount of memory, while parameter, gradient and optimizer states usage (that we’ll discuss later) stays roughly independent of the sequence length and batch size.</p>
399
+
400
+ <p><strong>For large input tokens (a.k.a large batch-sizes/sequences), activations become by far the largest memory burden.</strong> </p>
401
+
402
+ <p>Is there a way to tame this “activation explosion”? Good question, reader!</p>
403
+
404
+ <p>It’s time to explain our first technique – called <strong><em>activation recomputation</em></strong> – which will help us cap the activation memory footprint. An essential tool in today’s large model training toolbox.</p>
405
+
406
  <h3>Activation recomputation</h3>
407
 
408
+ <p>The general idea behind <strong><em>activation recomputation</em></strong> – also called <em>gradient checkpointing</em> or <em>rematerialization</em> – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically only store activations at a few key points along the model architecture, discard the rest and recompute them on the fly during the backward pass from the nearest saved activations, basically performing a sub-part of the forward pass again to trade off memory for compute. It generally looks like this:</p>
409
+
410
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
411
+
412
+ <p>There are several strategies to select key activations to store:</p>
413
+
414
+ <ul>
415
+ <li><strong>Full</strong>: We checkpoint activations at the transition point between each layer of the Transformer model. This is usually called the <code>full</code> strategy since it requires recomputing a forward pass through each layer, essentially adding a full forward pass during the backward pass. This strategy saves the most memory but is the most expensive one in terms of compute. It generally increases the compute cost and time by up to 30-40%, which is very noticeable (see the sketch after this list for how it looks in code).</li>
416
+ <li><strong>Selective</strong>: In general we can do better than full. The authors of the recomputation paper<d-cite bibtex-key="korthikanti2022recomputation"></d-cite> did a detailed analysis studying which activations grow the largest and have the cheapest recomputation cost in terms of FLOPs. Turns out that the attention computations fall in that category, and thus we can usually discard them and focus on checkpointing the expensive feedforward computations. For a GPT-3 (175B) model this means <strong>70% activation memory reduction at a 2.7% compute cost</strong>.</li>
417
+ </ul>
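+ <p>In PyTorch, the <code>full</code> strategy is usually just a call to <code>torch.utils.checkpoint</code> around each layer. Here is a minimal sketch, with a toy block standing in for a transformer layer (the module names and sizes are ours, for illustration only):</p>
+ <pre><code class="language-python">
+ import torch
+ from torch.utils.checkpoint import checkpoint
+ 
+ class Block(torch.nn.Module):  # toy stand-in for a transformer layer
+     def __init__(self, h):
+         super().__init__()
+         self.ff = torch.nn.Sequential(torch.nn.Linear(h, 4 * h), torch.nn.GELU(), torch.nn.Linear(4 * h, h))
+ 
+     def forward(self, x):
+         return x + self.ff(x)
+ 
+ class Model(torch.nn.Module):
+     def __init__(self, h=1024, n_layers=8, recompute=True):
+         super().__init__()
+         self.layers = torch.nn.ModuleList(Block(h) for _ in range(n_layers))
+         self.recompute = recompute
+ 
+     def forward(self, x):
+         for layer in self.layers:
+             if self.recompute:
+                 # only the layer input is saved; activations inside the block
+                 # are recomputed during the backward pass
+                 x = checkpoint(layer, x, use_reentrant=False)
+             else:
+                 x = layer(x)
+         return x
+ </code></pre>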
418
+
419
+ <aside>In recent models like DeepSeek V3, selective checkpointing is performed, storing an even smaller slice of the attention activations – using so-called “Multi-Head Latent Attention” (MLA) – to optimize activation memory usage.</aside>
420
+
421
+ <p>Let’s see how drastically recomputation strategies can in practice reduce the memory footprint and how selective recomputation strikes a nice balance between memory saving and recomputation cost:</p>
422
+
423
+ <p><img alt="llama-8b-memory-bars--recomp.png" src="/assets/images/placeholder.png" /></p>
424
+
425
+ <aside>When you’re measuring how efficient your training setup is at using the accelerator’s available compute, you may want to take recomputation into account when measuring the total FLOPS (floating point operations per second) of your training setup and comparing it to the theoretical maximum FLOPS of your GPU/TPU/accelerator to estimate utilization. Taking recomputation into account when calculating FLOPS for a training step gives a value called “hardware FLOPS”, which is the real number of operations performed on the accelerator. Dividing this number by the duration of one training step and by the maximum accelerator FLOPS yields the <em>Hardware FLOPS Utilization (HFU)</em>.</aside>
426
 
427
+ <aside>However, when comparing various accelerators, what really matters at the end of the day is the start-to-end time needed to train the same model on the same dataset, i.e. if an accelerator allows us to skip recomputation and thus perform fewer operations per second while training faster, it should be rewarded. The alternative is thus to compute what is called <em>Model FLOPS Utilization (MFU)</em>, which in contrast to HFU only accounts for the operations required to compute the forward and backward passes, not recomputation, i.e. it is specific to the model, not the training implementation.</aside>
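+ <p>To make the HFU/MFU distinction concrete, here is a toy calculation (all numbers are made up for illustration, except the peak, which is the commonly quoted H100 BF16 dense figure):</p>
+ <pre><code class="language-python">
+ peak_flops = 989e12      # e.g. H100 BF16 dense peak, FLOPs per second
+ step_time = 2.0          # measured seconds per training step
+ model_flops = 1.2e15     # FLOPs strictly required for one forward+backward pass
+ hardware_flops = 1.5e15  # same, but also counting the recomputed forward passes
+ 
+ hfu = hardware_flops / (step_time * peak_flops)  # what the accelerator actually executed
+ mfu = model_flops / (step_time * peak_flops)     # what the model strictly needed
+ print(f"HFU = {hfu:.1%}, MFU = {mfu:.1%}")
+ </code></pre>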
428
+
429
+ <p>Most training frameworks these days use FlashAttention (which we’ll cover a bit later), which natively integrates activation recomputation in its optimization strategy by recomputing attention scores and matrices in the backward pass instead of storing them. Thus most people using Flash Attention are already making use of selective recomputation.</p>
430
+
431
+ <p><strong>As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.</strong> </p>
432
+
433
+ <p>This trade-off is particularly advantageous on hardware with small high-speed memory, like GPUs, as accessing memory is typically slower than performing computations. Despite the additional operations involved, the overall effect is thus often faster computation as well, in addition to the much lower memory footprint.</p>
434
+
435
+ <p>Now that we’ve learned about recomputation, we can tame the activations memory usage as we saw in the above graphs!</p>
436
+
437
+ <p>However, activations still bear a linear dependence on the batch size, and all our profiles in the barplots above were using <code>bs=1</code>, so as we move to larger batch sizes it might become an issue again. Do not despair as we have a second tool in our box - <strong><em>gradient accumulation</em></strong> to the rescue!</p>
438
+
439
+ <h3>Gradient accumulation</h3>
440
+
441
+ <p>Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to avoid memory explosion when doing this.</p>
442
+
443
+ <p>With <em>gradient accumulation</em> we split our batch into micro-batches, do forward and backward passes repeatedly on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients for each micro-batch before doing a final optimizer step. In practice, we perform the optimization step not on the sum but on the average of the gradients, so the result is independent of the number of gradient accumulation steps.</p>
444
+
445
+ <p>Let’s call the batch size for each forward pass the <code>micro batch size</code> (mbs). We’ll refer to the overall batch size between each optimizer step as the <code>global batch size</code> (gbs). If we do one optimizer step for each 8 forward/backward passes, the <code>global batch size</code> will be 8 times the <code>micro batch size</code>.</p>
446
+
447
+ <p>What we now call <code>global batch size</code> thus corresponds to what we’ve called up to now just <code>batch size</code> for simplicity (we now make our terms more precise to avoid ambiguity).</p>
448
+
449
+ <p>With gradient accumulation the global batch size can be simply computed as follows:</p>
450
+
451
+ <d-math block>
452
+ bs = gbs = mbs \times grad\_acc
453
+ </d-math>
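+ <p>In a training loop this is only a few lines. Here is a minimal PyTorch-style sketch (the <code>model</code>, <code>optimizer</code>, <code>loss_fn</code> and <code>dataloader</code> objects are assumed to already exist):</p>
+ <pre><code class="language-python">
+ grad_acc_steps = 8  # gbs = mbs * grad_acc_steps
+ 
+ optimizer.zero_grad()
+ for step, (inputs, targets) in enumerate(dataloader):  # each batch holds mbs samples
+     loss = loss_fn(model(inputs), targets) / grad_acc_steps  # divide so we average, not sum
+     loss.backward()                                          # gradients accumulate in the .grad buffers
+     if (step + 1) % grad_acc_steps == 0:
+         optimizer.step()       # one optimizer step per global batch
+         optimizer.zero_grad()
+ </code></pre>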
454
+
455
+ <p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch! </p>
456
+
457
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
458
+
459
+ <p><strong>Gradient accumulation allows us to reduce the activation memory, which grows linearly with batch size, by processing the batch as a sequence of smaller micro-batches. There is a small overhead caused by the additional forward and backward passes.</strong></p>
460
+
461
+ <aside>Using gradient accumulation means we need to keep buffers where we accumulate gradients, and these persist throughout a training step. Without gradient accumulation, gradients are computed in the backward pass while the activation memory is freed, which means a lower peak memory.</aside>
462
+
463
+ <p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p>
464
+
465
+ <p>Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called <strong><em>data parallelism</em></strong>, which is just a parallel version of gradient accumulation.</p>
466
+
467
+ <p><strong>TODO: add profiling here or not?</strong></p>
468
+
469
  <h2>Data Parallelism</h2>
470
 
471
  <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
src/bibliography.bib CHANGED
@@ -367,4 +367,40 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
367
  archivePrefix={arXiv},
368
  primaryClass={cs.CL},
369
  url={https://arxiv.org/abs/2412.19437},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  }
 
367
  archivePrefix={arXiv},
368
  primaryClass={cs.CL},
369
  url={https://arxiv.org/abs/2412.19437},
370
+ }
371
+ @misc{mccandlish2018largebatchtraining,
372
+ title={An Empirical Model of Large-Batch Training},
373
+ author={Sam McCandlish and Jared Kaplan and Dario Amodei and OpenAI Dota Team},
374
+ year={2018},
375
+ eprint={1812.06162},
376
+ archivePrefix={arXiv},
377
+ primaryClass={cs.LG},
378
+ url={https://arxiv.org/abs/1812.06162},
379
+ }
380
+ @misc{micikevicius2018mixedprecisiontraining,
381
+ title={Mixed Precision Training},
382
+ author={Paulius Micikevicius and Sharan Narang and Jonah Alben and Gregory Diamos and Erich Elsen and David Garcia and Boris Ginsburg and Michael Houston and Oleksii Kuchaiev and Ganesh Venkatesh and Hao Wu},
383
+ year={2018},
384
+ eprint={1710.03740},
385
+ archivePrefix={arXiv},
386
+ primaryClass={cs.AI},
387
+ url={https://arxiv.org/abs/1710.03740},
388
+ }
389
+ @misc{rajbhandari2020zero,
390
+ title={ZeRO: Memory Optimizations Toward Training Trillion Parameter Models},
391
+ author={Samyam Rajbhandari and Jeff Rasley and Olatunji Ruwase and Yuxiong He},
392
+ year={2020},
393
+ eprint={1910.02054},
394
+ archivePrefix={arXiv},
395
+ primaryClass={cs.LG},
396
+ url={https://arxiv.org/abs/1910.02054},
397
+ }
398
+ @misc{korthikanti2022recomputation,
399
+ title={Reducing Activation Recomputation in Large Transformer Models},
400
+ author={Vijay Korthikanti and Jared Casper and Sangkug Lym and Lawrence McAfee and Michael Andersch and Mohammad Shoeybi and Bryan Catanzaro},
401
+ year={2022},
402
+ eprint={2205.05198},
403
+ archivePrefix={arXiv},
404
+ primaryClass={cs.LG},
405
+ url={https://arxiv.org/abs/2205.05198},
406
  }
src/index.html CHANGED
@@ -184,6 +184,7 @@
184
  <p>As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training let’s take a quick high level look on we’ll cover in the post.</p>
185
 
186
  <h2>TL;DR</h2>
 
187
  <p>This book is very extensive so we decide to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size the fastest possible.</p>
188
  <p>When scaling up models and input batches, we quickly end up in situations where either our target batch size won't fit in memory, or/and the model itself is too large to fit in a single GPU's memory.</p>
189
  <p>To solve this scaling issue we’ll need to carefully evaluate different parallelization strategies and find the optimal balance between three main factors:</p>
@@ -213,18 +214,258 @@
213
 
214
  <h2>First Steps: Training on one GPU</h2>
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  <h3>Memory usage in Transformers</h3>
 
 
 
 
 
 
 
 
 
217
 
 
 
 
 
 
 
 
 
 
 
 
218
  <h4>Memory profiling a training step</h4>
219
-
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  <h4>Weights/grads/optimizer states memory</h4>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
 
 
222
  <h4>Activations memory</h4>
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  <h3>Activation recomputation</h3>
225
 
226
- <h3>Gradient accumulation</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  <h2>Data Parallelism</h2>
229
 
230
  <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
 
184
  <p>As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training let’s take a quick high level look on we’ll cover in the post.</p>
185
 
186
  <h2>TL;DR</h2>
187
+
188
  <p>This book is very extensive so we decide to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size the fastest possible.</p>
189
  <p>When scaling up models and input batches, we quickly end up in situations where either our target batch size won't fit in memory, or/and the model itself is too large to fit in a single GPU's memory.</p>
190
  <p>To solve this scaling issue we’ll need to carefully evaluate different parallelization strategies and find the optimal balance between three main factors:</p>
 
214
 
215
  <h2>First Steps: Training on one GPU</h2>
216
 
217
+ <p>Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps: </p>
218
+
219
+ <ol>
220
+ <li>a forward pass which passes inputs through the model to yield its outputs,</li>
221
+ <li>a backward pass to compute the gradients, and</li>
222
+ <li>an optimization step using the gradients to update the parameters</li>
223
+ </ol>
224
+
225
+ <p>It looks generally like this: </p>
226
+ <p><img alt="image.png" src="assets/images/placeholder.png" /></p>
227
+
228
+ <aside>As we’ll see later, these steps may be repeated or intertwined but for now we’ll start simple.</aside>
229
+
230
+ <p>In this figure, the boxes on the top line can be seen as successive layers inside a model (same for the last line). The red boxes are the associated gradients for each of these layers, computed during the backward pass.</p>
231
+
232
+ <p>The batch size (<d-math>bs</d-math>) is one of the important hyper-parameters for model training and affects both model convergence and throughput.</p>
233
+
234
+ <p>If the batch size is too small, gradients will tend to be noisy and the model may not be able to converge to the most optimal performance; on the contrary, a small batch size can be useful early in training to navigate the training landscape quickly. On the other hand, a batch size that is too large will make less use of each training token, rendering convergence slower and wasting compute. You can find a nice discussion of this topic in OpenAI’s paper on large batch training<d-cite bibtex-key="mccandlish2018largebatchtraining"></d-cite> or Section 4.2 of the MiniMax-01 <a href="https://filecdn.minimax.chat/_Arxiv_MiniMax_01_Report.pdf">technical report</a>.</p>
235
+
236
+ <aside>For instance, during DeepSeek-V3/R1 training “the batch size is gradually increased from 3072 to 15360 in the training of the first 469B tokens, and then keeps 15360 in the remaining training”.</aside>
237
+
238
+ <p>Batch size also affects the time it takes to train on a given text dataset: a small batch size will require more optimizer steps to train on the same number of samples. Optimizer steps are costly (in compute time) and the total time to train will thus increase compared to using a larger batch size. This being said, note that the batch size can often be adjusted quite broadly around the optimal batch size without major impact on the performance of the model, i.e. the sensitivity of the final model performance to the exact batch size value is usually rather low around the optimum.</p>
239
+
240
+ <p>In the LLM pretraining community, batch sizes are commonly reported in terms of tokens rather than in number of samples (<d-math>bst</d-math> = Batch Size Tokens); this makes training numbers generally independent of the exact input sequence length used during training.</p>
241
+
242
+ <p>In the simplest case, training on a single machine, the <d-math>bs</d-math> (in samples) and <d-math>bst</d-math> can be computed from the model input sequence length (<d-math>seq</d-math>) as follows:</p>
243
+
244
+ <aside>From here onward we’ll show the formulas for the batch size in terms of samples but you can always get its token-unit counterpart by multiplying it with the sequence length.
245
+ </aside>
246
+
247
+ <d-math block>
248
+ bst = bs * seq
249
+ </d-math>
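+ <p>For instance, reaching a roughly 4M-token batch at a 4096-token sequence length works out as follows (a quick back-of-the-envelope check, with made-up target numbers):</p>
+ <pre><code class="language-python">
+ seq = 4096
+ target_bst = 4_000_000        # ~4M tokens per optimizer step
+ bs = target_bst // seq        # 976 samples
+ print(bs, bs * seq)           # 976 samples -> 3,997,696 tokens per batch
+ </code></pre>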
250
+
251
+ <p>A sweet spot for recent LLM training is typically on the order of 4-60 million tokens per batch. However, a typical issue when scaling training to these large batch sizes is running out of memory, i.e. our GPU doesn’t have enough memory.</p>
252
+
253
+ <aside>Note: Llama 1 was trained with a batch size of ~4M tokens for 1.4 trillion tokens while DeepSeek was trained with a batch size of ~60M tokens for 14 trillion tokens.
254
+ </aside>
255
+
256
+ <p><strong>It’s time to tackle our first scaling problem: what if our model starts exploding GPU memory before we’ve reached our target batch size (maybe in some cases even when using the lowest possible batch size, <code>bs=1</code>)?</strong></p>
257
+
258
+ <p>Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.</p>
259
+
260
  <h3>Memory usage in Transformers</h3>
261
+