Update autotune configuration to avoid crash on AMD devices

#2 by ror - opened
torch-ext/triton_layer_norm/layer_norm.py CHANGED
@@ -16,6 +16,22 @@ import triton
 import triton.language as tl
 
 
+autotune_configs = [
+    triton.Config({}, num_warps=1),
+    triton.Config({}, num_warps=2),
+    triton.Config({}, num_warps=4),
+    triton.Config({}, num_warps=8),
+    triton.Config({}, num_warps=16),
+    triton.Config({}, num_warps=32),
+]
+
+if torch.cuda.is_available():
+    is_amd_device = ("AMD" in torch.cuda.get_device_name())
+    # AMD devices have a maximum of 16 warps, so we remove the 32 warps autotune config
+    if is_amd_device and autotune_configs[-1].num_warps == 32:
+        autotune_configs.pop()
+
+
 def layer_norm_ref(
     x,
     weight,
@@ -128,14 +144,7 @@ def rms_norm_ref(
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=autotune_configs[:],
     key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@@ -407,14 +416,7 @@ def _layer_norm_fwd(
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=autotune_configs[:],
     key=[
         "N",
         "HAS_DRESIDUAL",