title: "Optimizing LLM Performance Using Triton" | |
format: | |
revealjs: | |
theme: dark | |
transition: slide | |
slide-number: true | |
author: "Matej Sirovatka" | |
date: "2025-02-22" | |
## `whoami`

- My name is Matej
- I'm a Master's student at Brno University of Technology
- I'm currently working on distributed training at Hugging Face 🤗

## `What is Triton?`

- Open-source programming language for GPU kernels, developed by OpenAI
- Designed for AI/ML workloads
- Simplifies GPU programming compared to CUDA

## `Why Optimize with Triton?`

- Simple yet effective
- Less headache than CUDA
- GPUs go `brrrrrrr` 🚀
- Feel cool when your kernel is faster than PyTorch 😎

## `Example Problem: KL Divergence`

- Commonly used in LLMs for knowledge distillation
- For probability distributions $P$ and $Q$, the Kullback-Leibler divergence is defined as:

$$
D_{KL}(P \| Q) = \sum_{i} P_i \log\left(\frac{P_i}{Q_i}\right)
$$

```python
import torch
from torch.nn.functional import kl_div


def kl_div_torch(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # F.kl_div expects log-probabilities as its first argument and computes
    # sum(target * (log(target) - input)) = D_KL(P || Q) with reduction="sum"
    return kl_div(q.log(), p, reduction="sum")
```
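
A quick sanity check (toy values, assuming `p` and `q` are already valid probability distributions):

```python
p = torch.softmax(torch.randn(8), dim=-1)
q = torch.softmax(torch.randn(8), dim=-1)
print(kl_div_torch(p, q))  # non-negative scalar, zero only when p == q
```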

## `How about Triton?`

```python
import triton
import triton.language as tl


@triton.jit
def kl_div_triton(
    p_ptr, q_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    # each program instance handles one block of BLOCK_SIZE elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard against out-of-bounds accesses
    p = tl.load(p_ptr + offsets, mask=mask)
    q = tl.load(q_ptr + offsets, mask=mask)
    # element-wise term P_i * (log P_i - log Q_i)
    output = p * (tl.log(p) - tl.log(q))
    tl.store(output_ptr + offsets, output, mask=mask)
```
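
A minimal launch sketch (the wrapper name and block size are illustrative, and `p`/`q` are assumed to be contiguous CUDA tensors):

```python
import torch
import triton


def kl_div_elementwise(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # the kernel writes per-element terms; summing them gives D_KL(P || Q)
    n_elements = p.numel()
    output = torch.empty_like(p)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)  # one program per block
    kl_div_triton[grid](p, q, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output.sum()
```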

## `How to integrate with PyTorch?`

- How to use our custom kernel with PyTorch autograd?

```python
import torch


class KlDiv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, p, q):
        ctx.save_for_backward(p, q)  # both are needed to compute the gradients
        output = torch.empty_like(p)
        n_elements = p.numel()
        # Triton expects the launch grid as a tuple (or a callable returning one)
        grid = ((n_elements + 512 - 1) // 512,)
        kl_div_triton[grid](p, q, output, n_elements, BLOCK_SIZE=512)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        p, q = ctx.saved_tensors
        # Calculate gradients (another triton kernel)
        return ...
```
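
A brief usage sketch, assuming the backward kernel has been filled in (tensor sizes are illustrative):

```python
p = torch.softmax(torch.randn(4096, device="cuda"), dim=-1).requires_grad_()
q = torch.softmax(torch.randn(4096, device="cuda"), dim=-1)

loss = KlDiv.apply(p, q).sum()  # the custom Function behaves like any other op
loss.backward()                 # runs our backward pass
print(p.grad.shape)
```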

## `Some benchmarks`

- The KL divergence kernel currently used in [Liger Kernel](https://github.com/linkedin/liger-kernel), written by @me

## `Do I have to write everything?`

- TL;DR: No
- Many cool projects are already using Triton
- Better integration with PyTorch and even Hugging Face 🤗
- Liger Kernel, Unsloth AI, etc.

## `So how can I use this in my LLM? 🚀`

- Liger Kernel is a great example, showing how to integrate with the Hugging Face 🤗 Trainer

```diff
- from transformers import AutoModelForCausalLM
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"

- model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path)

# training/inference logic...
```
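
To make the point concrete, a toy fine-tuning sketch (not from the slides): only the model class changes, the rest of the 🤗 Trainer setup is what you would write for any `transformers` model. The dataset and hyperparameters below are placeholders.

```python
from datasets import Dataset
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
model = AutoLigerKernelForCausalLM.from_pretrained(model_path)

# tiny placeholder dataset, just to show that the training loop is unchanged
texts = ["Triton kernels make GPUs go brrr.", "Kernel fusion saves memory."]
train_dataset = Dataset.from_dict({"text": texts}).map(
    lambda ex: tokenizer(ex["text"], truncation=True), remove_columns=["text"]
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```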

## `Key Optimization Techniques adopted by Liger Kernel`

- Kernel fusion (see the sketch below)
- Domain-specific optimizations
- Memory access patterns
- Preemptive memory freeing
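
A toy illustration of kernel fusion, not Liger's actual code: SiLU (`x * sigmoid(x)`) computed in a single Triton kernel, so the input is read from global memory once rather than once per elementwise op as in eager PyTorch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def silu_fused_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    out = x * tl.sigmoid(x)  # sigmoid and multiply fused in one memory pass
    tl.store(out_ptr + offsets, out, mask=mask)


def silu_fused(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    silu_fused_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out
```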

## `Aaand some more benchmarks 📊`

- Saving memory is key to running bigger batch sizes on smaller GPUs

## `Last benchmark I promise...`

- But is it faster? Yes, it is!

*Attention is all you need, so I thank you for yours!* 🤗