Files changed (2)
  1. dist/index.html +90 -18
  2. src/index.html +40 -18
dist/index.html CHANGED
@@ -327,7 +327,10 @@
327
 
328
  <h4>Profiling the memory usage</h4>
329
 
330
- <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
 
 
 
331
 
332
  <!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
333
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
@@ -596,7 +599,7 @@
596
 
597
  <p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
598
 
599
- <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
600
 
601
  <p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</strong></em>, which handles the synchronization and communication between GPU instances and nodes.</p>
602
 
@@ -809,7 +812,7 @@
809
  <ul>
810
  <li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
811
  <li>Backward pass with all gradients, but different microbatches across DP ranks</li>
812
- <li>Perform an reduce-scatter <strong>[TODO ADD link!]</strong> on the gradients (reduce-scatter is 2 times faster than all reduce! <em>Yay, a third communication primitive!</em>)</li>
813
  <li>Each replica performs an optimizer step (it holds only <d-math>\frac{1}{N_d}</d-math> of the optimizer states), updating only its <d-math>\frac{1}{N_d}</d-math> of the fp32 parameters, and then its <d-math>\frac{1}{N_d}</d-math> of the bf16 parameters.</li>
814
  <li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
815
  </ul>
@@ -1179,7 +1182,7 @@
1179
  </script>
1180
  <!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
1181
 
1182
- <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
1183
 
1184
  <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is what the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
1185
 
@@ -1214,16 +1217,16 @@
1214
  </ul>
1215
 
1216
  <p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
1217
-
1218
- <p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
1219
-
1220
  <div class="note-box">
1221
  <p class="note-box-title">📝 Note</p>
1222
  <p class="note-box-content">
1223
- <p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to allreduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.
1224
  </p>
1225
  </div>
1226
 
 
 
1227
  <p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
1228
 
1229
  <h2>Context Parallelism</h2>
@@ -1828,12 +1831,12 @@
1828
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1829
 
1830
  <p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
1831
- <p>TODO: Original figure from https://blog.codingconfessions.com/p/gpu-computing.</p>
1832
 
1833
  <p>The memory side is also highly hierarchical, with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong>, shared by all SMs; finally, there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.</p>
1834
 
1835
  <p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
1836
- <p>TODO: Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p>
1837
 
1838
  <p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, taking advantage of this hierarchical organization of compute and memory.</p>
1839
 
@@ -1870,7 +1873,6 @@
1870
  x & \text{if } x \geq 0
1871
  \end{cases}
1872
  </d-math>
1873
- <p>TODO: something off with spacing but seems the rendering engine</p>
1874
 
1875
  <p>You can start with a simple PyTorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
1876
 
@@ -2297,7 +2299,7 @@
2297
  <td>Above without FP32 grad accumulation</td>
2298
  <td>bf16</td>
2299
  <td>fp32</td>
2300
- <td></td>
2301
  <td>bf16</td>
2302
  <td>bf16</td>
2303
  <td>fp32 + fp32</td>
@@ -2306,8 +2308,8 @@
2306
  <tr>
2307
  <td>Transformer Engine</td>
2308
  <td>fp8</td>
2309
- <td></td>
2310
- <td></td>
2311
  <td>fp32</td>
2312
  <td>fp32</td>
2313
  <td>fp32 + fp32</td>
@@ -2346,7 +2348,7 @@
2346
  </tbody>
2347
  </table>
2348
 
2349
- <p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow public implementations of this, please head to the nanotron’s implementation in [TODO: link to appendix]. </p>
2350
 
2351
  <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">has been announced</a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
2352
 
@@ -2432,9 +2434,8 @@
2432
  <h3>What’s next?</h3>
2433
 
2434
  <p>You should now have a good overview of all the distributed training concepts, but there are still things to learn and details we couldn’t cover. To go deeper into the field, we recommend some of the following steps:</p>
2435
-
2436
  <ul>
2437
- <li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in [TODO References]</li>
2438
  <li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” once you’ve implemented it yourself.</li>
2439
  <li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get into any ML field!</li>
2440
  </ul>
@@ -2464,6 +2465,11 @@
2464
  <a href="https://arxiv.org/abs/2312.11805"><strong>Gemini</strong></a>
2465
  <p>Presents Google's multimodal model architecture capable of processing text, images, audio, and video inputs.</p>
2466
  </div>
2467
 
2468
  <div>
2469
  <a href="https://arxiv.org/abs/2412.19437v1"><strong>DeepSeek-V3</strong></a>
@@ -2472,7 +2478,6 @@
2472
 
2473
 
2474
  <h3>Training Frameworks</h3>
2475
-
2476
  <div>
2477
  <a href="https://github.com/facebookresearch/fairscale/tree/main"><strong>FairScale</strong></a>
2478
  <p>PyTorch extension library for large-scale training, offering various parallelism and optimization techniques.</p>
@@ -2525,6 +2530,11 @@
2525
  <p>Comprehensive guide to understanding and optimizing GPU memory usage in PyTorch.</p>
2526
  </div>
2527

2528
  <div>
2529
  <a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"><strong>TensorBoard Profiler Tutorial</strong></a>
2530
  <p>Guide to using TensorBoard's profiling tools for PyTorch models.</p>
@@ -2586,6 +2596,11 @@
2586
  <a href="https://arxiv.org/abs/1710.03740"><strong>Mixed precision training</strong></a>
2587
  <p>Introduces mixed precision training techniques for deep learning models.</p>
2588
  </div>
2589
 
2590
  <h3>Hardware</h3>
2591
 
@@ -2603,6 +2618,11 @@
2603
  <a href="https://www.semianalysis.com/p/100000-h100-clusters-power-network"><strong>Semianalysis - 100k H100 cluster</strong></a>
2604
  <p>Analysis of large-scale H100 GPU clusters and their implications for AI infrastructure.</p>
2605
  </div>
2606
 
2607
  <h3>Others</h3>
2608
 
@@ -2630,9 +2650,61 @@
2630
  <a href="https://www.harmdevries.com/post/context-length/"><strong>Harm's blog for long context</strong></a>
2631
  <p>Investigation into long context training in terms of data and training cost.</p>
2632
  </div>
2633
 
2634
  <h2>Appendix</h2>
2635

2636
  </d-article>
2637
 
2638
  <d-appendix>
 
327
 
328
  <h4>Profiling the memory usage</h4>
329
 
330
+ <p>Using this snippet, we can understand how memory is allocated throughout training. We can see that memory utilization is not static but varies a lot over the course of training and within a single training step:</p>
331
+
332
+ <aside>Check out <a target="_self" href="#a1%3A_distributed_training_profiling" class="">A1: Distributed Training Profiling</a> for a walkthrough of how to profile your model.</aside>
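  <p>For reference, here is a minimal sketch of what such a profiling snippet can look like, using PyTorch’s built-in memory-history recorder (the toy model, shapes, and output filename are illustrative assumptions, not the original snippet):</p>
  <pre><code class="language-python">
import torch

# Start recording every CUDA allocation/free, with stack traces.
torch.cuda.memory._record_memory_history(max_entries=100_000)

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 1024)
).cuda()
optimizer = torch.optim.AdamW(model.parameters())

for _ in range(5):  # a few steps are enough to see the per-step pattern
    batch = torch.randn(32, 1024, device="cuda")
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

# Dump a snapshot and stop recording.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
  </code></pre>
  <p>The resulting <code>memory_snapshot.pickle</code> can be dropped into <a href="https://pytorch.org/memory_viz">pytorch.org/memory_viz</a> to visualize the allocation timeline.</p>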
333
+
334
 
335
  <!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
336
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
 
599
 
600
  <p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
601
 
602
+ <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
603
 
604
  <p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</strong></em>, which handles the synchronization and communication between GPU instances and nodes.</p>
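  <p>As a minimal illustration (a sketch meant to be launched with <code>torchrun</code>; the script name and tensor contents are arbitrary), an all-reduce sums a tensor across all ranks and leaves the identical result on every rank:</p>
  <pre><code class="language-python">
# torchrun --nproc_per_node=2 all_reduce_demo.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Each rank starts with different data...
x = torch.full((4,), float(rank + 1), device="cuda")

# ...and after the all-reduce every rank holds the same element-wise sum.
dist.all_reduce(x, op=dist.ReduceOp.SUM)
print(f"rank {rank}: {x.tolist()}")

dist.destroy_process_group()
  </code></pre>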
605
 
 
812
  <ul>
813
  <li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
814
  <li>Backward pass with all gradients, but different microbatches across DP ranks</li>
815
+ <li>Perform a reduce-scatter on the gradients, as in the sketch after this list (reduce-scatter is 2 times faster than all-reduce! <em>Yay, a third communication primitive!</em>)</li>
816
  <li>Each replica performs an optimizer step (it holds only <d-math>\frac{1}{N_d}</d-math> of the optimizer states), updating only its <d-math>\frac{1}{N_d}</d-math> of the fp32 parameters, and then its <d-math>\frac{1}{N_d}</d-math> of the bf16 parameters.</li>
817
  <li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
818
  </ul>
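  <p>To make the flow above concrete, here is a minimal sketch of the per-step communication pattern using raw PyTorch collectives (the pre-flattened tensors, the even divisibility by the world size, and the helper name are simplifying assumptions of this sketch, not nanotron’s actual implementation):</p>
  <pre><code class="language-python">
import torch
import torch.distributed as dist

def zero3_style_step(flat_grads, flat_params_bf16, local_params_fp32, optimizer, group=None):
    """Sketch of one ZeRO-3 style step.

    flat_grads:        full bf16 gradient tensor present on every rank after backward
    flat_params_bf16:  full bf16 parameter tensor, re-materialized every step
    local_params_fp32: this rank's 1/N_d shard of the fp32 master parameters,
                       which is what `optimizer` was built over
    """
    world_size = dist.get_world_size(group)
    shard_numel = flat_grads.numel() // world_size  # assumes even divisibility

    # 1) Reduce-scatter: each rank keeps the averaged gradients of its own shard only.
    local_grads = torch.empty(shard_numel, dtype=flat_grads.dtype, device=flat_grads.device)
    dist.reduce_scatter_tensor(local_grads, flat_grads, op=dist.ReduceOp.AVG, group=group)

    # 2) Local optimizer step on the fp32 shard (1/N_d of the optimizer states).
    local_params_fp32.grad = local_grads.to(torch.float32)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # 3) All-gather the updated bf16 shards so every rank again holds the full
    #    bf16 parameters for the next forward pass.
    dist.all_gather_into_tensor(flat_params_bf16, local_params_fp32.to(torch.bfloat16), group=group)
  </code></pre>
  <p>In real implementations most of the engineering effort goes into bucketing these flattened tensors and overlapping the collectives with the backward pass, but the communication pattern is the one above.</p>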
 
1182
  </script>
1183
  <!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
1184
 
1185
+ <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see the <a target="_self" href="#a_quick_focus_on_ring_allreduce" class="">A quick focus on Ring AllReduce</a> section in the appendix), they’re actually equivalent in terms of communication. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).</p>
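  <p>A quick way to convince yourself of this equivalence is a toy check (a sketch meant to be launched with <code>torchrun</code>; tensor sizes are arbitrary) comparing a single all-reduce against a reduce-scatter followed by an all-gather:</p>
  <pre><code class="language-python">
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)

torch.manual_seed(rank)  # different data on every rank
x = torch.randn(world_size * 4, device="cuda")

# Path 1: one all-reduce.
all_reduced = x.clone()
dist.all_reduce(all_reduced, op=dist.ReduceOp.SUM)

# Path 2: reduce-scatter the shards, then all-gather them back.
shard = torch.empty(4, device="cuda")
dist.reduce_scatter_tensor(shard, x, op=dist.ReduceOp.SUM)
gathered = torch.empty_like(x)
dist.all_gather_into_tensor(gathered, shard)

# Identical up to floating-point summation order.
print(f"rank {rank}: max diff = {(all_reduced - gathered).abs().max().item()}")
dist.destroy_process_group()
  </code></pre>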
1186
 
1187
  <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
1188
 
 
1217
  </ul>
1218
 
1219
  <p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
1220
+
 
 
1221
  <div class="note-box">
1222
  <p class="note-box-title">📝 Note</p>
1223
  <p class="note-box-content">
1224
+ <p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to all-reduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.
1225
  </p>
1226
  </div>
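  <p>To make the note above concrete, here is a minimal sketch of that synchronization (the <code>tp_group</code> handle and the helper name are assumptions of this sketch; production frameworks usually tag the affected parameters rather than matching on module type):</p>
  <pre><code class="language-python">
import torch
import torch.distributed as dist

def sync_layernorm_grads(model: torch.nn.Module, tp_group: dist.ProcessGroup) -> None:
    """Average LayerNorm weight/bias gradients across TP ranks.

    In the SP region each TP rank sees a different slice of the sequence, so these
    gradients differ per rank and must be averaged before the optimizer step,
    just like DP averages gradients across replicas.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            for param in module.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=tp_group)
  </code></pre>
  <p>It would be called once after the backward pass and before <code>optimizer.step()</code>.</p>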
1227
 
1228
+ <p>However, there are two limits to TP and SP: 1) if we scale the sequence length, the activation memory will still blow up in the TP region, and 2) if the model is too big to fit with TP=8, we will see a massive slow-down due to the inter-node connectivity.</p>
1229
+
1230
  <p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
1231
 
1232
  <h2>Context Parallelism</h2>
 
1831
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1832
 
1833
  <p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
1834
+ <p><em>Source: https://blog.codingconfessions.com/p/gpu-computing.</em></p>
1835
 
1836
  <p>The memory side is also highly hierarchical, with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong>, shared by all SMs; finally, there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.</p>
1837
 
1838
  <p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
1839
+ <p><em>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</em></p>
1840
 
1841
  <p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, taking advantage of this hierarchical organization of compute and memory.</p>
1842
 
 
1873
  x & \text{if } x \geq 0
1874
  \end{cases}
1875
  </d-math>
 
1876
 
1877
  <p>You can start with a simple PyTorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
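  <p>For instance, a minimal sketch along these lines (assuming the ELU-style activation defined above; the function name is illustrative):</p>
  <pre><code class="language-python">
import torch

@torch.compile
def elu(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Piecewise definition from above: x for x >= 0, alpha * (exp(x) - 1) otherwise.
    return torch.where(x >= 0, x, alpha * (torch.exp(x) - 1))

x = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
y = elu(x)  # the first call triggers compilation; later calls reuse the generated kernel
  </code></pre>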
1878
 
 
2299
  <td>Above without FP32 grad accumulation</td>
2300
  <td>bf16</td>
2301
  <td>fp32</td>
2302
+ <td>n/a</td>
2303
  <td>bf16</td>
2304
  <td>bf16</td>
2305
  <td>fp32 + fp32</td>
 
2308
  <tr>
2309
  <td>Transformer Engine</td>
2310
  <td>fp8</td>
2311
+ <td>n/a</td>
2312
+ <td>n/a</td>
2313
  <td>fp32</td>
2314
  <td>fp32</td>
2315
  <td>fp32 + fp32</td>
 
2348
  </tbody>
2349
  </table>
2350
 
2351
+ <p>Overall, FP8 is still an experimental technique and methods are evolving, but it will likely soon become the standard, replacing bf16 mixed-precision. To follow a public implementation of this, please head to nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>.</p>
2352
 
2353
  <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">has been announced</a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
2354
 
 
2434
  <h3>What’s next?</h3>
2435
 
2436
  <p>You should now have a good overview of all the distributed training concepts, but there are still things to learn and details we couldn’t cover. To go deeper into the field, we recommend some of the following steps:</p>
 
2437
  <ul>
2438
+ <li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in <a target="_self" href="#references" class="">References</a>.</li>
2439
  <li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” once you’ve implemented it yourself.</li>
2440
  <li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get into any ML field!</li>
2441
  </ul>
 
2465
  <a href="https://arxiv.org/abs/2312.11805"><strong>Gemini</strong></a>
2466
  <p>Presents Google's multimodal model architecture capable of processing text, images, audio, and video inputs.</p>
2467
  </div>
2468
+
2469
+ <div>
2470
+ <a href="https://arxiv.org/abs/2407.21783"><strong>Llama 3</strong></a>
2471
+ <p>The Llama 3 Herd of Models</p>
2472
+ </div>
2473
 
2474
  <div>
2475
  <a href="https://arxiv.org/abs/2412.19437v1"><strong>DeepSeek-V3</strong></a>
 
2478
 
2479
 
2480
  <h3>Training Frameworks</h3>
 
2481
  <div>
2482
  <a href="https://github.com/facebookresearch/fairscale/tree/main"><strong>FairScale</strong></a>
2483
  <p>PyTorch extension library for large-scale training, offering various parallelism and optimization techniques.</p>
 
2530
  <p>Comprehensive guide to understanding and optimizing GPU memory usage in PyTorch.</p>
2531
  </div>
2532
 
2533
+ <div>
2534
+ <a href="https://huggingface.co/blog/train_memory"><strong>Memory profiling walkthrough on a simple example</strong></a>
2535
+ <p>Visualize and understand GPU memory in PyTorch.</p>
2536
+ </div>
2537
+
2538
  <div>
2539
  <a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"><strong>TensorBoard Profiler Tutorial</strong></a>
2540
  <p>Guide to using TensorBoard's profiling tools for PyTorch models.</p>
 
2596
  <a href="https://arxiv.org/abs/1710.03740"><strong>Mixed precision training</strong></a>
2597
  <p>Introduces mixed precision training techniques for deep learning models.</p>
2598
  </div>
2599
+
2600
+ <div>
2601
+ <a href="https://main-horse.github.io/posts/visualizing-6d/"><strong>@main_horse blog</strong></a>
2602
+ <p>Visualizing 6D Mesh Parallelism</p>
2603
+ </div>
2604
 
2605
  <h3>Hardware</h3>
2606
 
 
2618
  <a href="https://www.semianalysis.com/p/100000-h100-clusters-power-network"><strong>Semianalysis - 100k H100 cluster</strong></a>
2619
  <p>Analysis of large-scale H100 GPU clusters and their implications for AI infrastructure.</p>
2620
  </div>
2621
+
2622
+ <div>
2623
+ <a href="https://modal.com/gpu-glossary/readme"><strong>Modal GPU Glossary </strong></a>
2624
+ <p>CUDA docs for humans</p>
2625
+ </div>
2626
 
2627
  <h3>Others</h3>
2628
 
 
2650
  <a href="https://www.harmdevries.com/post/context-length/"><strong>Harm's blog for long context</strong></a>
2651
  <p>Investigation into long context training in terms of data and training cost.</p>
2652
  </div>
2653
+
2654
+ <div>
2655
+ <a href="https://www.youtube.com/@GPUMODE/videos"><strong>GPU Mode</strong></a>
2656
+ <p>A GPU reading group and community.</p>
2657
+ </div>
2658
+
2659
+ <div>
2660
+ <a href="https://youtube.com/playlist?list=PLvtrkEledFjqOLuDB_9FWL3dgivYqc6-3&si=fKWPotx8BflLAUkf"><strong>EleutherAI Youtube channel</strong></a>
2661
+ <p>ML Scalability & Performance Reading Group</p>
2662
+ </div>
2663
+
2664
+ <div>
2665
+ <a href="https://jax-ml.github.io/scaling-book/"><strong>Google Jax Scaling book</strong></a>
2666
+ <p>How to Scale Your Model</p>
2667
+ </div>
2668
+
2669
+ <div>
2670
+ <a href="https://github.com/facebookresearch/capi/blob/main/fsdp.py"><strong>@fvsmassa & @TimDarcet FSDP</strong></a>
2671
+ <p>Standalone ~500 LoC FSDP implementation</p>
2672
+ </div>
2673
+
2674
+ <div>
2675
+ <a href="https://www.thonking.ai/"><strong>thonking.ai</strong></a>
2676
+ <p>Some of Horace He's blogposts</p>
2677
+ </div>
2678
+
2679
+ <div>
2680
+ <a href="https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad"><strong>Aleksa's ELI5 Flash Attention</strong></a>
2681
+ <p>Easy explanation of Flash Attention</p>
2682
+ </div>
2683
+
2684
 
2685
  <h2>Appendix</h2>
2686
 
2687
+ <h3>A0: Parallel Programming Crash Course</h3>
2688
+
2689
+ <h4>Broadcast</h4>
2690
+
2691
+ <h4>Reduce & AllReduce</h4>
2692
+
2693
+ <h4>A quick focus on Ring AllReduce</h4>
2694
+
2695
+ <h4>Gather & AllGather </h4>
2696
+
2697
+ <h4>Scatter & ReduceScatter</h4>
2698
+
2699
+ <h4>Barrier</h4>
2700
+
2701
+ <h4>NCCL: NVIDIA Collective Communications Library</h4>
2702
+
2703
+ <h3>A1: Distributed Training Profiling</h3>
2704
+
2705
+ <h3>A2: Math for Compute/Comms Overlap</h3>
2706
+
2707
+
2708
  </d-article>
2709
 
2710
  <d-appendix>
src/index.html CHANGED
@@ -327,7 +327,10 @@
327
 
328
  <h4>Profiling the memory usage</h4>
329
 
330
- <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
 
 
 
331
 
332
  <!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
333
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
@@ -596,7 +599,7 @@
596
 
597
  <p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
598
 
599
- <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
600
 
601
  <p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</strong></em>, which handles the synchronization and communication between GPU instances and nodes.</p>
602
 
@@ -809,7 +812,7 @@
809
  <ul>
810
  <li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
811
  <li>Backward pass with all gradients, but different microbatches across DP ranks</li>
812
- <li>Perform an reduce-scatter <strong>[TODO ADD link!]</strong> on the gradients (reduce-scatter is 2 times faster than all reduce! <em>Yay, a third communication primitive!</em>)</li>
813
  <li>Each replica performs an optimizer step (it holds only <d-math>\frac{1}{N_d}</d-math> of the optimizer states), updating only its <d-math>\frac{1}{N_d}</d-math> of the fp32 parameters, and then its <d-math>\frac{1}{N_d}</d-math> of the bf16 parameters.</li>
814
  <li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
815
  </ul>
@@ -1179,7 +1182,7 @@
1179
  </script>
1180
  <!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
1181
 
1182
- <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
1183
 
1184
  <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is what the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
1185
 
@@ -1214,16 +1217,16 @@
1214
  </ul>
1215
 
1216
  <p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
1217
-
1218
- <p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
1219
-
1220
  <div class="note-box">
1221
  <p class="note-box-title">📝 Note</p>
1222
  <p class="note-box-content">
1223
- <p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to allreduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.
1224
  </p>
1225
  </div>
1226
 
 
 
1227
  <p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
1228
 
1229
  <h2>Context Parallelism</h2>
@@ -1828,12 +1831,12 @@
1828
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1829
 
1830
  <p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
1831
- <p>TODO: Original figure from https://blog.codingconfessions.com/p/gpu-computing.</p>
1832
 
1833
  <p>The memory side is also highly hierarchical, with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong>, shared by all SMs; finally, there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.</p>
1834
 
1835
  <p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
1836
- <p>TODO: Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg</p>
1837
 
1838
  <p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, taking advantage of this hierarchical organization of compute and memory.</p>
1839
 
@@ -1870,7 +1873,6 @@
1870
  x & \text{if } x \geq 0
1871
  \end{cases}
1872
  </d-math>
1873
- <p>TODO: something off with spacing but seems the rendering engine</p>
1874
 
1875
  <p>You can start with a simple PyTorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
1876
 
@@ -2297,7 +2299,7 @@
2297
  <td>Above without FP32 grad accumulation</td>
2298
  <td>bf16</td>
2299
  <td>fp32</td>
2300
- <td></td>
2301
  <td>bf16</td>
2302
  <td>bf16</td>
2303
  <td>fp32 + fp32</td>
@@ -2306,8 +2308,8 @@
2306
  <tr>
2307
  <td>Transformer Engine</td>
2308
  <td>fp8</td>
2309
- <td></td>
2310
- <td></td>
2311
  <td>fp32</td>
2312
  <td>fp32</td>
2313
  <td>fp32 + fp32</td>
@@ -2346,7 +2348,7 @@
2346
  </tbody>
2347
  </table>
2348
 
2349
- <p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow public implementations of this, please head to the nanotron’s implementation in [TODO: link to appendix]. </p>
2350
 
2351
  <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">has been announced</a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
2352
 
@@ -2432,9 +2434,8 @@
2432
  <h3>What’s next?</h3>
2433
 
2434
  <p>You should now have a good overview of all the distributed training concepts, but there are still things to learn and details we couldn’t cover. To go deeper into the field, we recommend some of the following steps:</p>
2435
-
2436
  <ul>
2437
- <li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in [TODO References]</li>
2438
  <li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” once you’ve implemented it yourself.</li>
2439
  <li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get into any ML field!</li>
2440
  </ul>
@@ -2672,7 +2673,7 @@
2672
 
2673
  <div>
2674
  <a href="https://www.thonking.ai/"><strong>thonking.ai</strong></a>
2675
- <p>Some of Horace He blogpost</p>
2676
  </div>
2677
 
2678
  <div>
@@ -2683,6 +2684,27 @@
2683
 
2684
  <h2>Appendix</h2>
2685

2686
  </d-article>
2687
 
2688
  <d-appendix>
 
327
 
328
  <h4>Profiling the memory usage</h4>
329
 
330
+ <p>Using this snippet, we can understand how memory is allocated throughout training. We can see that memory utilization is not static but varies a lot over the course of training and within a single training step:</p>
331
+
332
+ <aside>Check out <a target="_self" href="#a1%3A_distributed_training_profiling" class="">A1: Distributed Training Profiling</a> for a walkthrough of how to profile your model.</aside>
333
+
334
 
335
  <!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
336
  <div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
 
599
 
600
  <p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
601
 
602
+ <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
603
 
604
  <p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</strong></em>, which handles the synchronization and communication between GPU instances and nodes.</p>
605
 
 
812
  <ul>
813
  <li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
814
  <li>Backward pass with all gradients, but different microbatches across DP ranks</li>
815
+ <li>Perform a reduce-scatter on the gradients (reduce-scatter is 2 times faster than all-reduce! <em>Yay, a third communication primitive!</em>)</li>
816
  <li>Each replica performs an optimizer step (it holds only <d-math>\frac{1}{N_d}</d-math> of the optimizer states), updating only its <d-math>\frac{1}{N_d}</d-math> of the fp32 parameters, and then its <d-math>\frac{1}{N_d}</d-math> of the bf16 parameters.</li>
817
  <li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
818
  </ul>
 
1182
  </script>
1183
  <!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
1184
 
1185
+ <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see the <a target="_self" href="#a_quick_focus_on_ring_allreduce" class="">A quick focus on Ring AllReduce</a> section in the appendix), they’re actually equivalent in terms of communication. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).</p>
1186
 
1187
  <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is what the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
1188
 
 
1217
  </ul>
1218
 
1219
  <p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
1220
+
 
 
1221
  <div class="note-box">
1222
  <p class="note-box-title">📝 Note</p>
1223
  <p class="note-box-content">
1224
+ <p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to all-reduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.
1225
  </p>
1226
  </div>
1227
 
1228
+ <p>However, there are two limits to TP and SP: 1) if we scale the sequence length, the activation memory will still blow up in the TP region, and 2) if the model is too big to fit with TP=8, we will see a massive slow-down due to the inter-node connectivity.</p>
1229
+
1230
  <p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
1231
 
1232
  <h2>Context Parallelism</h2>
 
1831
  <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
1832
 
1833
  <p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
1834
+ <p><em>Source: https://blog.codingconfessions.com/p/gpu-computing.</em></p>
1835
 
1836
  <p>The memory side is also highly hierarchical, with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during execution; <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM; higher up is the <strong>L2 cache</strong>, shared by all SMs; finally, there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.</p>
1837
 
1838
  <p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
1839
+ <p><em>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</em></p>
1840
 
1841
  <p>The goal of the GPU is to run as many workloads as possible, in parallel, on the GPU cores, taking advantage of this hierarchical organization of compute and memory.</p>
1842
 
 
1873
  x & \text{if } x \geq 0
1874
  \end{cases}
1875
  </d-math>
 
1876
 
1877
  <p>You can start with a simple PyTorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
1878
 
 
2299
  <td>Above without FP32 grad accumulation</td>
2300
  <td>bf16</td>
2301
  <td>fp32</td>
2302
+ <td>n/a</td>
2303
  <td>bf16</td>
2304
  <td>bf16</td>
2305
  <td>fp32 + fp32</td>
 
2308
  <tr>
2309
  <td>Transformer Engine</td>
2310
  <td>fp8</td>
2311
+ <td>n/a</td>
2312
+ <td>n/a</td>
2313
  <td>fp32</td>
2314
  <td>fp32</td>
2315
  <td>fp32 + fp32</td>
 
2348
  </tbody>
2349
  </table>
2350
 
2351
+ <p>Overall, FP8 is still an experimental technique and methods are evolving, but it will likely soon become the standard, replacing bf16 mixed-precision. To follow a public implementation of this, please head to nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>.</p>
2352
 
2353
  <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">has been announced</a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
2354
 
 
2434
  <h3>What’s next?</h3>
2435
 
2436
  <p>You should now have a good overview of all the distributed training concepts, but there are still things to learn and details we couldn’t cover. To go deeper into the field, we recommend some of the following steps:</p>
 
2437
  <ul>
2438
+ <li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in <a target="_self" href="#references" class="">References</a>.</li>
2439
  <li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” once you’ve implemented it yourself.</li>
2440
  <li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get into any ML field!</li>
2441
  </ul>
 
2673
 
2674
  <div>
2675
  <a href="https://www.thonking.ai/"><strong>thonking.ai</strong></a>
2676
+ <p>Some of Horace He's blogposts</p>
2677
  </div>
2678
 
2679
  <div>
 
2684
 
2685
  <h2>Appendix</h2>
2686
 
2687
+ <h3>A0: Parallel Programming Crash Course</h3>
2688
+
2689
+ <h4>Broadcast</h4>
2690
+
2691
+ <h4>Reduce & AllReduce</h4>
2692
+
2693
+ <h4>A quick focus on Ring AllReduce</h4>
2694
+
2695
+ <h4>Gather & AllGather </h4>
2696
+
2697
+ <h4>Scatter & ReduceScatter</h4>
2698
+
2699
+ <h4>Barrier</h4>
2700
+
2701
+ <h4>NCCL: NVIDIA Collective Communications Library</h4>
2702
+
2703
+ <h3>A1: Distributed Training Profiling</h3>
2704
+
2705
+ <h3>A2: Math for Compute/Comms Overlap</h3>
2706
+
2707
+
2708
  </d-article>
2709
 
2710
  <d-appendix>