Spaces:
Running
Running
todos and references
Browse files- dist/index.html +90 -18
- src/index.html +40 -18
dist/index.html
CHANGED
@@ -327,7 +327,10 @@
|
|
327 |
|
328 |
<h4>Profiling the memory usage</h4>
|
329 |
|
330 |
-
<p>Using this snippet
|
|
|
|
|
|
|
331 |
|
332 |
<!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
|
333 |
<div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
|
@@ -596,7 +599,7 @@
|
|
596 |
|
597 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
598 |
|
599 |
-
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in
|
600 |
|
601 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
602 |
|
@@ -809,7 +812,7 @@
|
|
809 |
<ul>
|
810 |
<li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
|
811 |
<li>Backward pass with all gradients, but different microbatches across DP ranks</li>
|
812 |
-
<li>Perform an reduce-scatter
|
813 |
<li>- Each replica perform an optimizer step (has only <d-math>\frac{1}{N_d}</d-math> optimizer states) updates only on <d-math>\frac{1}{N_d}</d-math> of fp32 parameters, and then <d-math>\frac{1}{N_d}</d-math> of bf16 parameters.</li>
|
814 |
<li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
|
815 |
</ul>
|
@@ -1179,7 +1182,7 @@
|
|
1179 |
</script>
|
1180 |
<!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
|
1181 |
|
1182 |
-
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in
|
1183 |
|
1184 |
<p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
|
1185 |
|
@@ -1214,16 +1217,16 @@
|
|
1214 |
</ul>
|
1215 |
|
1216 |
<p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
|
1217 |
-
|
1218 |
-
<p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
|
1219 |
-
|
1220 |
<div class="note-box">
|
1221 |
<p class="note-box-title">📝 Note</p>
|
1222 |
<p class="note-box-content">
|
1223 |
-
<p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to
|
1224 |
</p>
|
1225 |
</div>
|
1226 |
|
|
|
|
|
1227 |
<p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
|
1228 |
|
1229 |
<h2>Context Parallelism</h2>
|
@@ -1828,12 +1831,12 @@
|
|
1828 |
<p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
|
1829 |
|
1830 |
<p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
|
1831 |
-
<p>
|
1832 |
|
1833 |
<p>The memory side is also highly hierarchical with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during executions, <strong>Shared Memory</strong> and <strong>L1 cache are</strong> shared between the threads running on a single SM, higher up is the <strong>L2 cache</strong> shared by all SMs, finally there is the <strong>Global Memory</strong> which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.</p>
|
1834 |
|
1835 |
<p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
|
1836 |
-
<p>
|
1837 |
|
1838 |
<p>The goal of GPU will be to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.</p>
|
1839 |
|
@@ -1870,7 +1873,6 @@
|
|
1870 |
x & \text{if } x \geq 0
|
1871 |
\end{cases}
|
1872 |
</d-math>
|
1873 |
-
<p>TODO: something off with spacing but seems the rendering engine</p>
|
1874 |
|
1875 |
<p>You can start by a simple pytorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
|
1876 |
|
@@ -2297,7 +2299,7 @@
|
|
2297 |
<td>Above without FP32 grad accumulation</td>
|
2298 |
<td>bf16</td>
|
2299 |
<td>fp32</td>
|
2300 |
-
<td
|
2301 |
<td>bf16</td>
|
2302 |
<td>bf16</td>
|
2303 |
<td>fp32 + fp32</td>
|
@@ -2306,8 +2308,8 @@
|
|
2306 |
<tr>
|
2307 |
<td>Transformer Engine</td>
|
2308 |
<td>fp8</td>
|
2309 |
-
<td
|
2310 |
-
<td
|
2311 |
<td>fp32</td>
|
2312 |
<td>fp32</td>
|
2313 |
<td>fp32 + fp32</td>
|
@@ -2346,7 +2348,7 @@
|
|
2346 |
</tbody>
|
2347 |
</table>
|
2348 |
|
2349 |
-
<p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow public implementations of this, please head to the nanotron’s implementation in
|
2350 |
|
2351 |
<p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
|
2352 |
|
@@ -2432,9 +2434,8 @@
|
|
2432 |
<h3>What’s next?</h3>
|
2433 |
|
2434 |
<p>You should have a good overview of all the distributed training concepts but there are still things to learn and details we couldn’t cover. To get deeper in the field we recommend doing some of the following steps:</p>
|
2435 |
-
|
2436 |
<ul>
|
2437 |
-
<li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in
|
2438 |
<li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” if you implemented it yourself.</li>
|
2439 |
<li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get in any ML field!</li>
|
2440 |
</ul>
|
@@ -2464,6 +2465,11 @@
|
|
2464 |
<a href="https://arxiv.org/abs/2312.11805"><strong>Gemini</strong></a>
|
2465 |
<p>Presents Google's multimodal model architecture capable of processing text, images, audio, and video inputs.</p>
|
2466 |
</div>
|
|
|
|
|
|
|
|
|
|
|
2467 |
|
2468 |
<div>
|
2469 |
<a href="https://arxiv.org/abs/2412.19437v1"><strong>DeepSeek-V3</strong></a>
|
@@ -2472,7 +2478,6 @@
|
|
2472 |
|
2473 |
|
2474 |
<h3>Training Frameworks</h3>
|
2475 |
-
|
2476 |
<div>
|
2477 |
<a href="https://github.com/facebookresearch/fairscale/tree/main"><strong>FairScale</strong></a>
|
2478 |
<p>PyTorch extension library for large-scale training, offering various parallelism and optimization techniques.</p>
|
@@ -2525,6 +2530,11 @@
|
|
2525 |
<p>Comprehensive guide to understanding and optimizing GPU memory usage in PyTorch.</p>
|
2526 |
</div>
|
2527 |
|
|
|
|
|
|
|
|
|
|
|
2528 |
<div>
|
2529 |
<a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"><strong>TensorBoard Profiler Tutorial</strong></a>
|
2530 |
<p>Guide to using TensorBoard's profiling tools for PyTorch models.</p>
|
@@ -2586,6 +2596,11 @@
|
|
2586 |
<a href="https://arxiv.org/abs/1710.03740"><strong>Mixed precision training</strong></a>
|
2587 |
<p>Introduces mixed precision training techniques for deep learning models.</p>
|
2588 |
</div>
|
|
|
|
|
|
|
|
|
|
|
2589 |
|
2590 |
<h3>Hardware</h3>
|
2591 |
|
@@ -2603,6 +2618,11 @@
|
|
2603 |
<a href="https://www.semianalysis.com/p/100000-h100-clusters-power-network"><strong>Semianalysis - 100k H100 cluster</strong></a>
|
2604 |
<p>Analysis of large-scale H100 GPU clusters and their implications for AI infrastructure.</p>
|
2605 |
</div>
|
|
|
|
|
|
|
|
|
|
|
2606 |
|
2607 |
<h3>Others</h3>
|
2608 |
|
@@ -2630,9 +2650,61 @@
|
|
2630 |
<a href="https://www.harmdevries.com/post/context-length/"><strong>Harm's blog for long context</strong></a>
|
2631 |
<p>Investigation into long context training in terms of data and training cost.</p>
|
2632 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2633 |
|
2634 |
<h2>Appendix</h2>
|
2635 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2636 |
</d-article>
|
2637 |
|
2638 |
<d-appendix>
|
|
|
327 |
|
328 |
<h4>Profiling the memory usage</h4>
|
329 |
|
330 |
+
<p>Using this snippet, we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
|
331 |
+
|
332 |
+
<aside>Check out <a target="_self" href="#a1%3A_distributed_training_profiling" class="">A1: Distributed Training Profiling</a> for a walkthrough how to profile your model.</aside>
|
333 |
+
|
334 |
|
335 |
<!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
|
336 |
<div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
|
|
|
599 |
|
600 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
601 |
|
602 |
+
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
|
603 |
|
604 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
605 |
|
|
|
812 |
<ul>
|
813 |
<li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
|
814 |
<li>Backward pass with all gradients, but different microbatches across DP ranks</li>
|
815 |
+
<li>Perform an reduce-scatter on the gradients (reduce-scatter is 2 times faster than all reduce! <em>Yay, a third communication primitive!</em>)</li>
|
816 |
<li>- Each replica perform an optimizer step (has only <d-math>\frac{1}{N_d}</d-math> optimizer states) updates only on <d-math>\frac{1}{N_d}</d-math> of fp32 parameters, and then <d-math>\frac{1}{N_d}</d-math> of bf16 parameters.</li>
|
817 |
<li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
|
818 |
</ul>
|
|
|
1182 |
</script>
|
1183 |
<!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
|
1184 |
|
1185 |
+
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see the <a target="_self" href="#a_quick_focus_on_ring_allreduce" class="">A quick focus on Ring AllReduce</a> section in the appendix) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
|
1186 |
|
1187 |
<p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
|
1188 |
|
|
|
1217 |
</ul>
|
1218 |
|
1219 |
<p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
|
1220 |
+
|
|
|
|
|
1221 |
<div class="note-box">
|
1222 |
<p class="note-box-title">📝 Note</p>
|
1223 |
<p class="note-box-content">
|
1224 |
+
<p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to all-reduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.
|
1225 |
</p>
|
1226 |
</div>
|
1227 |
|
1228 |
+
<p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
|
1229 |
+
|
1230 |
<p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
|
1231 |
|
1232 |
<h2>Context Parallelism</h2>
|
|
|
1831 |
<p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
|
1832 |
|
1833 |
<p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
|
1834 |
+
<p><em>Source: https://blog.codingconfessions.com/p/gpu-computing.</em></p>
|
1835 |
|
1836 |
<p>The memory side is also highly hierarchical with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during executions, <strong>Shared Memory</strong> and <strong>L1 cache are</strong> shared between the threads running on a single SM, higher up is the <strong>L2 cache</strong> shared by all SMs, finally there is the <strong>Global Memory</strong> which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.</p>
|
1837 |
|
1838 |
<p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
|
1839 |
+
<p><em>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</em></p>
|
1840 |
|
1841 |
<p>The goal of GPU will be to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.</p>
|
1842 |
|
|
|
1873 |
x & \text{if } x \geq 0
|
1874 |
\end{cases}
|
1875 |
</d-math>
|
|
|
1876 |
|
1877 |
<p>You can start by a simple pytorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
|
1878 |
|
|
|
2299 |
<td>Above without FP32 grad accumulation</td>
|
2300 |
<td>bf16</td>
|
2301 |
<td>fp32</td>
|
2302 |
+
<td>n/a</td>
|
2303 |
<td>bf16</td>
|
2304 |
<td>bf16</td>
|
2305 |
<td>fp32 + fp32</td>
|
|
|
2308 |
<tr>
|
2309 |
<td>Transformer Engine</td>
|
2310 |
<td>fp8</td>
|
2311 |
+
<td>n/a</td>
|
2312 |
+
<td>n/a</td>
|
2313 |
<td>fp32</td>
|
2314 |
<td>fp32</td>
|
2315 |
<td>fp32 + fp32</td>
|
|
|
2348 |
</tbody>
|
2349 |
</table>
|
2350 |
|
2351 |
+
<p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow a public implementations of this, please head to the nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
|
2352 |
|
2353 |
<p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
|
2354 |
|
|
|
2434 |
<h3>What’s next?</h3>
|
2435 |
|
2436 |
<p>You should have a good overview of all the distributed training concepts but there are still things to learn and details we couldn’t cover. To get deeper in the field we recommend doing some of the following steps:</p>
|
|
|
2437 |
<ul>
|
2438 |
+
<li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in <a target="_self" href="#references" class="">References</a>.</li>
|
2439 |
<li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” if you implemented it yourself.</li>
|
2440 |
<li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get in any ML field!</li>
|
2441 |
</ul>
|
|
|
2465 |
<a href="https://arxiv.org/abs/2312.11805"><strong>Gemini</strong></a>
|
2466 |
<p>Presents Google's multimodal model architecture capable of processing text, images, audio, and video inputs.</p>
|
2467 |
</div>
|
2468 |
+
|
2469 |
+
<div>
|
2470 |
+
<a href="https://arxiv.org/abs/2407.21783"><strong>Llama 3</strong></a>
|
2471 |
+
<p>The Llama 3 Herd of Models</p>
|
2472 |
+
</div>
|
2473 |
|
2474 |
<div>
|
2475 |
<a href="https://arxiv.org/abs/2412.19437v1"><strong>DeepSeek-V3</strong></a>
|
|
|
2478 |
|
2479 |
|
2480 |
<h3>Training Frameworks</h3>
|
|
|
2481 |
<div>
|
2482 |
<a href="https://github.com/facebookresearch/fairscale/tree/main"><strong>FairScale</strong></a>
|
2483 |
<p>PyTorch extension library for large-scale training, offering various parallelism and optimization techniques.</p>
|
|
|
2530 |
<p>Comprehensive guide to understanding and optimizing GPU memory usage in PyTorch.</p>
|
2531 |
</div>
|
2532 |
|
2533 |
+
<div>
|
2534 |
+
<a href="https://huggingface.co/blog/train_memory"><strong>Memory profiling walkthrough on a simple example</strong></a>
|
2535 |
+
<p>Visualize and understand GPU memory in PyTorch.</p>
|
2536 |
+
</div>
|
2537 |
+
|
2538 |
<div>
|
2539 |
<a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"><strong>TensorBoard Profiler Tutorial</strong></a>
|
2540 |
<p>Guide to using TensorBoard's profiling tools for PyTorch models.</p>
|
|
|
2596 |
<a href="https://arxiv.org/abs/1710.03740"><strong>Mixed precision training</strong></a>
|
2597 |
<p>Introduces mixed precision training techniques for deep learning models.</p>
|
2598 |
</div>
|
2599 |
+
|
2600 |
+
<div>
|
2601 |
+
<a href="https://main-horse.github.io/posts/visualizing-6d/"><strong>@main_horse blog</strong></a>
|
2602 |
+
<p>Visualizing 6D Mesh Parallelism</p>
|
2603 |
+
</div>
|
2604 |
|
2605 |
<h3>Hardware</h3>
|
2606 |
|
|
|
2618 |
<a href="https://www.semianalysis.com/p/100000-h100-clusters-power-network"><strong>Semianalysis - 100k H100 cluster</strong></a>
|
2619 |
<p>Analysis of large-scale H100 GPU clusters and their implications for AI infrastructure.</p>
|
2620 |
</div>
|
2621 |
+
|
2622 |
+
<div>
|
2623 |
+
<a href="https://modal.com/gpu-glossary/readme"><strong>Modal GPU Glossary </strong></a>
|
2624 |
+
<p>CUDA docs for human</p>
|
2625 |
+
</div>
|
2626 |
|
2627 |
<h3>Others</h3>
|
2628 |
|
|
|
2650 |
<a href="https://www.harmdevries.com/post/context-length/"><strong>Harm's blog for long context</strong></a>
|
2651 |
<p>Investigation into long context training in terms of data and training cost.</p>
|
2652 |
</div>
|
2653 |
+
|
2654 |
+
<div>
|
2655 |
+
<a href="https://www.youtube.com/@GPUMODE/videos"><strong>GPU Mode</strong></a>
|
2656 |
+
<p>A GPU reading group and community.</p>
|
2657 |
+
</div>
|
2658 |
+
|
2659 |
+
<div>
|
2660 |
+
<a href="https://youtube.com/playlist?list=PLvtrkEledFjqOLuDB_9FWL3dgivYqc6-3&si=fKWPotx8BflLAUkf"><strong>EleutherAI Youtube channel</strong></a>
|
2661 |
+
<p>ML Scalability & Performance Reading Group</p>
|
2662 |
+
</div>
|
2663 |
+
|
2664 |
+
<div>
|
2665 |
+
<a href="https://jax-ml.github.io/scaling-book/"><strong>Google Jax Scaling book</strong></a>
|
2666 |
+
<p>How to Scale Your Model</p>
|
2667 |
+
</div>
|
2668 |
+
|
2669 |
+
<div>
|
2670 |
+
<a href="https://github.com/facebookresearch/capi/blob/main/fsdp.py"><strong>@fvsmassa & @TimDarcet FSDP</strong></a>
|
2671 |
+
<p>Standalone ~500 LoC FSDP implementation</p>
|
2672 |
+
</div>
|
2673 |
+
|
2674 |
+
<div>
|
2675 |
+
<a href="https://www.thonking.ai/"><strong>thonking.ai</strong></a>
|
2676 |
+
<p>Some of Horace He's blogposts</p>
|
2677 |
+
</div>
|
2678 |
+
|
2679 |
+
<div>
|
2680 |
+
<a href="https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad"><strong>Aleksa's ELI5 Flash Attention</strong></a>
|
2681 |
+
<p>Easy explanation of Flash Attention</p>
|
2682 |
+
</div>
|
2683 |
+
|
2684 |
|
2685 |
<h2>Appendix</h2>
|
2686 |
|
2687 |
+
<h3>A0: Parallel Programming Crash Course</h3>
|
2688 |
+
|
2689 |
+
<h4>Broadcast</h4>
|
2690 |
+
|
2691 |
+
<h4>Reduce & AllReduce</h4>
|
2692 |
+
|
2693 |
+
<h4>A quick focus on Ring AllReduce</h4>
|
2694 |
+
|
2695 |
+
<h4>Gather & AllGather </h4>
|
2696 |
+
|
2697 |
+
<h4>Scatter & ReduceScatter</h4>
|
2698 |
+
|
2699 |
+
<h4>Barrier</h4>
|
2700 |
+
|
2701 |
+
<h4>NCCL: NVIDIA Collective Communications Library</h4>
|
2702 |
+
|
2703 |
+
<h3>A1: Distributed Training Profiling</h3>
|
2704 |
+
|
2705 |
+
<h3>A2: Math for Compute/Comms Overlap</h3>
|
2706 |
+
|
2707 |
+
|
2708 |
</d-article>
|
2709 |
|
2710 |
<d-appendix>
|
src/index.html
CHANGED
@@ -327,7 +327,10 @@
|
|
327 |
|
328 |
<h4>Profiling the memory usage</h4>
|
329 |
|
330 |
-
<p>Using this snippet
|
|
|
|
|
|
|
331 |
|
332 |
<!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
|
333 |
<div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
|
@@ -596,7 +599,7 @@
|
|
596 |
|
597 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
598 |
|
599 |
-
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in
|
600 |
|
601 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
602 |
|
@@ -809,7 +812,7 @@
|
|
809 |
<ul>
|
810 |
<li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
|
811 |
<li>Backward pass with all gradients, but different microbatches across DP ranks</li>
|
812 |
-
<li>Perform an reduce-scatter
|
813 |
<li>- Each replica perform an optimizer step (has only <d-math>\frac{1}{N_d}</d-math> optimizer states) updates only on <d-math>\frac{1}{N_d}</d-math> of fp32 parameters, and then <d-math>\frac{1}{N_d}</d-math> of bf16 parameters.</li>
|
814 |
<li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
|
815 |
</ul>
|
@@ -1179,7 +1182,7 @@
|
|
1179 |
</script>
|
1180 |
<!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
|
1181 |
|
1182 |
-
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in
|
1183 |
|
1184 |
<p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
|
1185 |
|
@@ -1214,16 +1217,16 @@
|
|
1214 |
</ul>
|
1215 |
|
1216 |
<p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
|
1217 |
-
|
1218 |
-
<p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
|
1219 |
-
|
1220 |
<div class="note-box">
|
1221 |
<p class="note-box-title">📝 Note</p>
|
1222 |
<p class="note-box-content">
|
1223 |
-
<p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to
|
1224 |
</p>
|
1225 |
</div>
|
1226 |
|
|
|
|
|
1227 |
<p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
|
1228 |
|
1229 |
<h2>Context Parallelism</h2>
|
@@ -1828,12 +1831,12 @@
|
|
1828 |
<p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
|
1829 |
|
1830 |
<p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
|
1831 |
-
<p>
|
1832 |
|
1833 |
<p>The memory side is also highly hierarchical with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during executions, <strong>Shared Memory</strong> and <strong>L1 cache are</strong> shared between the threads running on a single SM, higher up is the <strong>L2 cache</strong> shared by all SMs, finally there is the <strong>Global Memory</strong> which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.</p>
|
1834 |
|
1835 |
<p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
|
1836 |
-
<p>
|
1837 |
|
1838 |
<p>The goal of GPU will be to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.</p>
|
1839 |
|
@@ -1870,7 +1873,6 @@
|
|
1870 |
x & \text{if } x \geq 0
|
1871 |
\end{cases}
|
1872 |
</d-math>
|
1873 |
-
<p>TODO: something off with spacing but seems the rendering engine</p>
|
1874 |
|
1875 |
<p>You can start by a simple pytorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
|
1876 |
|
@@ -2297,7 +2299,7 @@
|
|
2297 |
<td>Above without FP32 grad accumulation</td>
|
2298 |
<td>bf16</td>
|
2299 |
<td>fp32</td>
|
2300 |
-
<td
|
2301 |
<td>bf16</td>
|
2302 |
<td>bf16</td>
|
2303 |
<td>fp32 + fp32</td>
|
@@ -2306,8 +2308,8 @@
|
|
2306 |
<tr>
|
2307 |
<td>Transformer Engine</td>
|
2308 |
<td>fp8</td>
|
2309 |
-
<td
|
2310 |
-
<td
|
2311 |
<td>fp32</td>
|
2312 |
<td>fp32</td>
|
2313 |
<td>fp32 + fp32</td>
|
@@ -2346,7 +2348,7 @@
|
|
2346 |
</tbody>
|
2347 |
</table>
|
2348 |
|
2349 |
-
<p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow public implementations of this, please head to the nanotron’s implementation in
|
2350 |
|
2351 |
<p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
|
2352 |
|
@@ -2432,9 +2434,8 @@
|
|
2432 |
<h3>What’s next?</h3>
|
2433 |
|
2434 |
<p>You should have a good overview of all the distributed training concepts but there are still things to learn and details we couldn’t cover. To get deeper in the field we recommend doing some of the following steps:</p>
|
2435 |
-
|
2436 |
<ul>
|
2437 |
-
<li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in
|
2438 |
<li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” if you implemented it yourself.</li>
|
2439 |
<li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get in any ML field!</li>
|
2440 |
</ul>
|
@@ -2672,7 +2673,7 @@
|
|
2672 |
|
2673 |
<div>
|
2674 |
<a href="https://www.thonking.ai/"><strong>thonking.ai</strong></a>
|
2675 |
-
<p>Some of Horace He
|
2676 |
</div>
|
2677 |
|
2678 |
<div>
|
@@ -2683,6 +2684,27 @@
|
|
2683 |
|
2684 |
<h2>Appendix</h2>
|
2685 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2686 |
</d-article>
|
2687 |
|
2688 |
<d-appendix>
|
|
|
327 |
|
328 |
<h4>Profiling the memory usage</h4>
|
329 |
|
330 |
+
<p>Using this snippet, we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p>
|
331 |
+
|
332 |
+
<aside>Check out <a target="_self" href="#a1%3A_distributed_training_profiling" class="">A1: Distributed Training Profiling</a> for a walkthrough how to profile your model.</aside>
|
333 |
+
|
334 |
|
335 |
<!-- <div class="svg-container l-body-outset" id="svg-first_steps_memory_profile"> </div>
|
336 |
<div class="info" id="svg-first_steps_memory_profile-info">Hover over the elements to see their details</div>
|
|
|
599 |
|
600 |
<p><img alt="image.png" src="/assets/images/dp_diagram.png" /></p>
|
601 |
|
602 |
+
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in <a target="_self" href="#a0%3A_parallel_programming_crash_course" class="">A0: Parallel Programming Crash Course</a>.</aside>
|
603 |
|
604 |
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
605 |
|
|
|
812 |
<ul>
|
813 |
<li>Forward pass with all bf16 parameters, but different microbatches across DP ranks</li>
|
814 |
<li>Backward pass with all gradients, but different microbatches across DP ranks</li>
|
815 |
+
<li>Perform an reduce-scatter on the gradients (reduce-scatter is 2 times faster than all reduce! <em>Yay, a third communication primitive!</em>)</li>
|
816 |
<li>- Each replica perform an optimizer step (has only <d-math>\frac{1}{N_d}</d-math> optimizer states) updates only on <d-math>\frac{1}{N_d}</d-math> of fp32 parameters, and then <d-math>\frac{1}{N_d}</d-math> of bf16 parameters.</li>
|
817 |
<li>Perform an all-gather of bf16 parameters to send missing slices back to each replica. This is a new operation in ZeRO, and not used in vanilla DP.</li>
|
818 |
</ul>
|
|
|
1182 |
</script>
|
1183 |
<!-- <p><img alt="tp_sp_memoryusage.svg" src="/assets/images/tp_sp_memoryusage.svg" /></p> -->
|
1184 |
|
1185 |
+
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see the <a target="_self" href="#a_quick_focus_on_ring_allreduce" class="">A quick focus on Ring AllReduce</a> section in the appendix) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
|
1186 |
|
1187 |
<p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
|
1188 |
|
|
|
1217 |
</ul>
|
1218 |
|
1219 |
<p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
|
1220 |
+
|
|
|
|
|
1221 |
<div class="note-box">
|
1222 |
<p class="note-box-title">📝 Note</p>
|
1223 |
<p class="note-box-content">
|
1224 |
+
<p>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to all-reduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.
|
1225 |
</p>
|
1226 |
</div>
|
1227 |
|
1228 |
+
<p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
|
1229 |
+
|
1230 |
<p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
|
1231 |
|
1232 |
<h2>Context Parallelism</h2>
|
|
|
1831 |
<p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">docs for tensor cores</a> for details), each capable of handling multiple threads simultaneously.</p>
|
1832 |
|
1833 |
<p><img alt="image.png" src="/assets/images/diving_primergpu.svg" /></p>
|
1834 |
+
<p><em>Source: https://blog.codingconfessions.com/p/gpu-computing.</em></p>
|
1835 |
|
1836 |
<p>The memory side is also highly hierarchical with several layers of cache and memory: <strong>Registers</strong> are the smallest units and are private to the threads during executions, <strong>Shared Memory</strong> and <strong>L1 cache are</strong> shared between the threads running on a single SM, higher up is the <strong>L2 cache</strong> shared by all SMs, finally there is the <strong>Global Memory</strong> which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.</p>
|
1837 |
|
1838 |
<p><img alt="image.png" src="/assets/images/diving_primergpu2.svg" /></p>
|
1839 |
+
<p><em>Source: https://www.youtube.com/watch?v=ZQKMZIP3Fzg</em></p>
|
1840 |
|
1841 |
<p>The goal of GPU will be to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute/memory.</p>
|
1842 |
|
|
|
1873 |
x & \text{if } x \geq 0
|
1874 |
\end{cases}
|
1875 |
</d-math>
|
|
|
1876 |
|
1877 |
<p>You can start by a simple pytorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
|
1878 |
|
|
|
2299 |
<td>Above without FP32 grad accumulation</td>
|
2300 |
<td>bf16</td>
|
2301 |
<td>fp32</td>
|
2302 |
+
<td>n/a</td>
|
2303 |
<td>bf16</td>
|
2304 |
<td>bf16</td>
|
2305 |
<td>fp32 + fp32</td>
|
|
|
2308 |
<tr>
|
2309 |
<td>Transformer Engine</td>
|
2310 |
<td>fp8</td>
|
2311 |
+
<td>n/a</td>
|
2312 |
+
<td>n/a</td>
|
2313 |
<td>fp32</td>
|
2314 |
<td>fp32</td>
|
2315 |
<td>fp32 + fp32</td>
|
|
|
2348 |
</tbody>
|
2349 |
</table>
|
2350 |
|
2351 |
+
<p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow a public implementations of this, please head to the nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
|
2352 |
|
2353 |
<p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
|
2354 |
|
|
|
2434 |
<h3>What’s next?</h3>
|
2435 |
|
2436 |
<p>You should have a good overview of all the distributed training concepts but there are still things to learn and details we couldn’t cover. To get deeper in the field we recommend doing some of the following steps:</p>
|
|
|
2437 |
<ul>
|
2438 |
+
<li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in <a target="_self" href="#references" class="">References</a>.</li>
|
2439 |
<li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” if you implemented it yourself.</li>
|
2440 |
<li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get in any ML field!</li>
|
2441 |
</ul>
|
|
|
2673 |
|
2674 |
<div>
|
2675 |
<a href="https://www.thonking.ai/"><strong>thonking.ai</strong></a>
|
2676 |
+
<p>Some of Horace He's blogposts</p>
|
2677 |
</div>
|
2678 |
|
2679 |
<div>
|
|
|
2684 |
|
2685 |
<h2>Appendix</h2>
|
2686 |
|
2687 |
+
<h3>A0: Parallel Programming Crash Course</h3>
|
2688 |
+
|
2689 |
+
<h4>Broadcast</h4>
|
2690 |
+
|
2691 |
+
<h4>Reduce & AllReduce</h4>
|
2692 |
+
|
2693 |
+
<h4>A quick focus on Ring AllReduce</h4>
|
2694 |
+
|
2695 |
+
<h4>Gather & AllGather </h4>
|
2696 |
+
|
2697 |
+
<h4>Scatter & ReduceScatter</h4>
|
2698 |
+
|
2699 |
+
<h4>Barrier</h4>
|
2700 |
+
|
2701 |
+
<h4>NCCL: NVIDIA Collective Communications Library</h4>
|
2702 |
+
|
2703 |
+
<h3>A1: Distributed Training Profiling</h3>
|
2704 |
+
|
2705 |
+
<h3>A2: Math for Compute/Comms Overlap</h3>
|
2706 |
+
|
2707 |
+
|
2708 |
</d-article>
|
2709 |
|
2710 |
<d-appendix>
|