Update autotune configuration to avoid crash on AMD devices #2
opened by ror

torch-ext/triton_layer_norm/layer_norm.py  CHANGED
@@ -16,6 +16,22 @@ import triton
 import triton.language as tl
 
 
+autotune_configs = [
+    triton.Config({}, num_warps=1),
+    triton.Config({}, num_warps=2),
+    triton.Config({}, num_warps=4),
+    triton.Config({}, num_warps=8),
+    triton.Config({}, num_warps=16),
+    triton.Config({}, num_warps=32),
+]
+
+if torch.cuda.is_available():
+    is_amd_device = ("AMD" in torch.cuda.get_device_name())
+    # AMD devices have a maximum of 16 warps, so we remove the 32 warps autotune config
+    if is_amd_device and autotune_configs[-1].num_warps == 32:
+        autotune_configs.pop()
+
+
 def layer_norm_ref(
     x,
     weight,
@@ -128,14 +144,7 @@ def rms_norm_ref(
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=autotune_configs[:],
     key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@@ -407,14 +416,7 @@ def _layer_norm_fwd(
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=autotune_configs[:],
     key=[
         "N",
         "HAS_DRESIDUAL",