drbh committed
Commit 09e15a7 · 1 Parent(s): 3bdb4b8

fix: add quickstart and avoid autotune when no cuda

Files changed:
- README.md +54 -30
- readme_example.py +51 -0
- torch-ext/megablocks/backend/kernels.py +14 -0
README.md
CHANGED
@@ -4,39 +4,63 @@ tags:
 - kernel
 ---
 
-
+## Quickstart
 
 ```bash
-
+uv run https://huggingface.co/kernels-community/megablocks/raw/main/readme_example.py
 ```
 
-expected output:
 
+```python
+# /// script
+# requires-python = "==3.10"
+# dependencies = [
+#   "numpy",
+#   "kernels",
+#   "torch"
+# ]
+# ///
+
+import torch
+from collections import namedtuple
+
+from kernels import get_kernel
+
+# Make reproducible
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+
+# Download optimized kernels from the Hugging Face hub
+megablocks = get_kernel("kernels-community/megablocks")
+print("MegaBlocks kernel downloaded successfully.")
+
+model = megablocks.layers.MegaBlocksMoeMLP()
+model.experts = namedtuple("Experts", ["gate_up_proj", "gate_down_proj", "down_proj", "hidden_size"])
+print("MegaBlocksMoeMLP instance created successfully.")
+
+# Config
+ne, hs, isz = 128, 1152, 3072
+
+# Router with proper initialization
+model.router = torch.nn.Linear(hs, ne, device="cuda")
+torch.nn.init.kaiming_uniform_(model.router.weight)
+
+# Expert layers with realistic weights
+e = model.experts
+e.gate_up_proj = torch.nn.Parameter(torch.randn(ne, hs, isz, device="cuda") * 0.02)
+e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
+e.down_proj = torch.nn.Parameter(torch.randn(ne, 1536, hs, device="cuda") * 0.02)
+e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
+e.hidden_size = hs
+print("Expert layers initialized successfully.")
+
+# Test with normalized input
+x = torch.randn(1, 1, hs, device="cuda") * 0.1
+output, expert_weights = model(x)
+print("Model forward pass completed successfully.")
+
+print(f"Output shape: {output.shape}")
+print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
+print(f"Output: {output.flatten()[:10]}")
+print(f"Expert weights sum: {expert_weights.sum():.3f}")
 ```
-============== test session starts ===============
-platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
-rootdir: /home/ubuntu/Projects/megablocks-moe
-plugins: hypothesis-6.130.12
-collecting 43 items world_size=1
-collected 387 items
-
-tests/layers/moe_test.py ...........................................
-tests/ops/binned_gather_test.py .....................
-tests/ops/binned_scatter_test.py .....................
-tests/ops/cumsum_test.py ................................
-tests/ops/histogram_test.py ......................................................
-tests/ops/padded_gather_test.py ......................................
-tests/ops/padded_scatter_test.py ......................................................
-tests/ops/replicate_test.py ..................................................................................
-tests/ops/sort_test.py ..................
-tests/ops/topology_test.py ....................
-tests/test_mb_moe.py megablocks_moe module imported successfully.
-Available functions: ['Arguments', 'MLP', 'MoE', 'ParallelDroplessMLP', 'ParallelMLP', 'SparseGLU', 'SparseMLP', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_megablocks_a4f6452_dirty', '_ops', 'argsort', 'backend', 'cumsum', 'dMoE', 'exclusive_cumsum', 'get_load_balancing_loss', 'grouped_gemm_util', 'histogram', 'inclusive_cumsum', 'indices', 'layers', 'ops', 'replicate_backward', 'replicate_forward', 'sort', 'torch']
-.cumsum output: tensor([0, 1, 3, 6], device='cuda:0', dtype=torch.int16)
-...
-
-================ warnings summary ================
-...
--- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-======= 387 passed, 18 warnings in 54.63s ========
-```
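The `# /// script` header in the added example is PEP 723 inline metadata; `uv run` reads it to resolve numpy, kernels, and torch before executing the script. Note that the example allocates every tensor with `device="cuda"`, so it needs a GPU, while the autotune stub added further down in this commit only makes the Triton modules importable without one. A minimal guard, not part of this commit, that fails early with a readable message could look like:

```python
import torch

# Hypothetical guard (not part of this commit): exit with a clear message
# instead of hitting a CUDA error on the first device="cuda" allocation.
if not torch.cuda.is_available():
    raise SystemExit("readme_example.py needs a CUDA GPU to run the MoE forward pass.")
```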
readme_example.py
ADDED
@@ -0,0 +1,51 @@
(new file, 51 added lines, identical to the Python quickstart snippet added to README.md above)
torch-ext/megablocks/backend/kernels.py
CHANGED
@@ -5,6 +5,20 @@ import torch
 import triton
 import triton.language as tl
 
+# Stub triton autotune when testing in an env that does not have CUDA;
+# this approach preserves the original code but enables testing without a GPU.
+if torch.cuda.is_available() is False:
+    import warnings
+
+    warnings.warn("CUDA is not available. Triton autotuning is disabled.")
+
+    def _no_autotune(*args, **kwargs):
+        def deco(fn):
+            return fn
+        return deco
+
+    triton.autotune = _no_autotune
+
 
 def assert_is_tensor(x, ndim):
     if x.ndim != ndim: