File size: 21,435 Bytes
e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional, Tuple, Union

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
    gelu,
    gelu_approx,
    gelu_approx_grad,
    gelu_grad,
    squared_relu,
    squared_relu_grad,
)

# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


def init_to_zero(name):
    """Build an autotune pre-hook that zeroes the kernel argument called ``name``.

    The returned callable receives the kernel's argument dict, calls ``zero_()``
    on the entry stored under ``name`` and returns whatever ``zero_()`` returns.
    """

    def _zero_hook(nargs):
        target = nargs[name]
        return target.zero_()

    return _zero_hook


def get_configs_io_bound():
    """Return the small-tile autotuning configs (BLOCK_M in {16, 32}) for IO-bound shapes.

    Configs are enumerated with num_stages outermost, then BLOCK_M, BLOCK_K and
    BLOCK_N innermost. Narrow tiles (BLOCK_N <= 64) use 2 warps, wider ones use 4.
    SPLIT_K is always 1: split-K variants (which would need an ``init_to_zero('C')``
    pre-hook) are intentionally not generated.
    """
    return [
        triton.Config(
            {
                "BLOCK_M": tile_m,
                "BLOCK_N": tile_n,
                "BLOCK_K": tile_k,
                "SPLIT_K": 1,
            },
            num_stages=stages,
            num_warps=(4 if tile_n > 64 else 2),
        )
        for stages in (2, 3, 4, 5, 6)
        for tile_m in (16, 32)
        for tile_k in (32, 64)
        for tile_n in (32, 64, 128, 256)
    ]


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
        # good for int8
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_fwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    bias,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr,
    B_COLMAJOR: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUT: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """
    Kernel for computing C = activation(A x B + bias).

    - A (input) has shape (M, K)
    - B (weight) is addressed as a (K, N) matrix through stride_bk / stride_bn
    - bias has shape (N,) and is only read when BIAS is set
    - C (output) has shape (M, N) with a unit column stride
    - ACT_INPUT (optional) has shape (M, N), laid out like C; when SAVE_ACT_INPUT is
      set it receives the pre-activation values A x B + bias for backward computations

    Each program computes one BLOCK_M x BLOCK_N tile of C, accumulating in float32
    over the whole K dimension (every config uses SPLIT_K == 1).
    Rows >= M and columns >= N of a boundary tile are never written.
    """

    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance: programs are walked in groups of
    # GROUP_M tile-rows, column by column, so neighbouring programs share A/B tiles
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols back into bounds so the A and B loads need no M/N
    # masks. The results computed for those lanes are garbage and are dropped by the
    # masked stores at the end of the kernel.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    # Unit-stride layouts skip the K-stride multiply.
    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # k counts the K elements still to be consumed; only the last (partial) step
    # needs a K mask, and EVEN_K proves there is none.
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    # Putting bias after the matmul (instead of before) was found to be faster by the
    # original authors. Columns >= N get a bias of 0, so those lanes are not valid results.
    if BIAS:
        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # optional: save the activation inputs
    if SAVE_ACT_INPUT:
        # Use the unwrapped rm/rn plus a bounds mask: storing through the wrapped
        # ram/rbn would let out-of-range lanes (which lack the bias) overwrite valid
        # columns rn % N, racing with the program that owns them.
        # act_in_ptrs = ACT_INPUT + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        act_in_ptrs = ACT_INPUT + rm[:, None] * stride_cm + rn[None, :]
        act_mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(act_in_ptrs, acc, mask=act_mask)

    # optional: fused activation applied to the accumulator before the final store
    if ACTIVATION == "gelu":
        acc = gelu(acc)
    elif ACTIVATION == "gelu_approx":
        acc = gelu_approx(acc)
    elif ACTIVATION == "squared_relu":
        acc = squared_relu(acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result; the mask keeps boundary tiles from writing past row M - 1
    # (out of bounds) or past column N - 1 (into the following row of C)
    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = "id",
    save_act_input: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute e = activation(x @ weight.T + bias).

    This wrapper launches the `kernel_fwd` Triton kernel.

    :param x: input tensor of shape (*batch, K); flattened to (M, K) for the kernel
    :param weight: weight matrix of shape (N, K), same dtype as x
    :param bias: an optional bias tensor of shape (N,), same dtype as x
    :param activation: activation name fused into the kernel; one of
        "id", "gelu", "gelu_approx", "squared_relu"
    :param save_act_input: if True, also return the pre-activation values
        x @ weight.T + bias (e.g. to feed the backward pass)
    :return: the (*batch, N) result tensor when save_act_input is False, otherwise
        the tuple (result, act_input) of two (*batch, N) tensors
    :raises AssertionError: on an unknown activation or mismatched dtypes / shapes
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    # Copy to contiguous only when neither dimension has unit stride.
    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None

    assert (
        x.dtype == weight.dtype
    ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert (
            x.dtype == bias.dtype
        ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert (
        x_reshaped.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"

    M, K = x_reshaped.shape
    N, K = weight.shape

    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_fwd[grid](
        output,
        act_input,
        x_reshaped,
        weight,  # data ptrs
        # Without a bias, x is passed only as a placeholder pointer; BIAS=False below
        # keeps the kernel from reading it.
        bias if bias is not None else x,
        M,  # shapes
        N,
        K,
        M // 32,  # autotune key: re-tune only when a dimension changes bucket of 32
        N // 32,
        K // 32,
        stride_cm=output.stride(0),  # strides
        # stride_cn=output.stride(1),
        stride_am=x_reshaped.stride(0),
        stride_ak=x_reshaped.stride(1),
        # weight is (N, K) in memory but is read as the (K, N) operand B
        stride_bk=weight.stride(1),
        stride_bn=weight.stride(0),
        BIAS=bias is not None,  # optional fused bias
        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
        ACTIVATION=activation,  # optional fused activation
        A_ROWMAJOR=x_reshaped.stride(1) == 1,
        B_COLMAJOR=weight.stride(1) == 1,
        GROUP_M=8,  # speed optimization: group the programs
    )

    output = output.reshape(*batch_shape, N)
    if not save_act_input:
        return output
    return output, act_input.reshape(*batch_shape, N)


@triton.autotune(

    configs=[

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8

        ),

        triton.Config(

            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8

        ),

        triton.Config(

            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2

        ),

        # good for int8

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},

            num_stages=3,

            num_warps=8,

        ),

        triton.Config(

            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},

            num_stages=3,

            num_warps=8,

        ),

        triton.Config(

            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},

            num_stages=4,

            num_warps=4,

        ),

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4

        ),

        triton.Config(

            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2

        ),

    ]

    + get_configs_io_bound(),

    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],

    prune_configs_by={

        "early_config_prune": early_config_prune,

        "perf_model": estimate_matmul_time,

        "top_k": 10,

    },

)
@triton.heuristics(

    {

        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,

    }

)
@triton.jit
def kernel_bwd(

    C,  # Pointers to matrices

    ACT_INPUT,

    A,

    B,

    # Matrix dimensions

    M,

    N,

    K,

    CACHE_KEY_M,

    CACHE_KEY_N,

    CACHE_KEY_K,

    # The stride variables represent how much to increase the ptr by when moving by 1

    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr

    # by to get the element one row down (A has M rows)

    stride_cm,

    # stride_cn,  # Assume that stride_cn == 1

    stride_am,

    stride_ak,

    stride_bk,

    stride_bn,

    # Meta-parameters

    BLOCK_M: tl.constexpr,

    GROUP_M: tl.constexpr,

    BLOCK_N: tl.constexpr,

    BLOCK_K: tl.constexpr,

    # split k not used, not performant with activation, kept because early_config_prune is expecting it

    SPLIT_K: tl.constexpr,

    EVEN_K: tl.constexpr,

    ACTIVATION: tl.constexpr,

):
    """
    Kernel for computing C = (A x B) * activation_grad(ACT_INPUT).

    As launched by triton_dgrad_act:
    - A (grad_output) has shape (M, K)
    - B (weight) has shape (K, N)
    - C (grad_input) has shape (M, N) with a unit column stride
    - ACT_INPUT holds the pre-activation values saved by the forward pass. It is only
      read when ACTIVATION != "id", and it is indexed exactly like C (row stride
      stride_cm, unit column stride), so it must be laid out as an (M, N) matrix
      matching C.

    Each program computes one BLOCK_M x BLOCK_N tile of C, accumulating in float32
    over the whole K dimension. With ACTIVATION == "id" the result is the plain
    product A x B. Only in-bounds elements (rows < M, cols < N) are stored.
    """

    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance: programs are walked in groups of
    # GROUP_M tile-rows, column by column
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols back into bounds so loads need no M/N masks; the
    # values computed for those lanes are dropped by the masked store at the end.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # k counts the K elements still to be consumed; only a partial final step needs
    # the K mask, and EVEN_K proves there is none.
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # optional: multiply by the activation derivative evaluated at the saved
    # pre-activation values. The wrapped ram/rbn indices keep these loads in bounds.
    if ACTIVATION != "id":
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
    if ACTIVATION == "gelu":
        acc *= gelu_grad(act_input)
    elif ACTIVATION == "gelu_approx":
        acc *= gelu_approx_grad(act_input)
    elif ACTIVATION == "squared_relu":
        acc *= squared_relu_grad(act_input)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result; the mask discards lanes outside the (M, N) output
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = "id",
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute grad_input = (grad_output @ weight) * activation_grad(act_input).

    This wrapper launches the `kernel_bwd` Triton kernel. With activation "id" no
    activation derivative is applied and ``act_input`` is not used.

    :param grad_output: gradient tensor of shape (*batch, K), flattened to (M, K)
    :param weight: weight matrix of shape (K, N), same dtype as grad_output
    :param activation: activation name; one of "id", "gelu", "gelu_approx", "squared_relu"
    :param act_input: pre-activation values saved by the forward pass (e.g. the second
        output of ``triton_linear_act(..., save_act_input=True)``). Required unless
        activation == "id"; must contain exactly M * N elements.
    :return: grad_input tensor of shape (*batch, N)
    :raises AssertionError: on an unknown activation, mismatched dtypes / shapes, or a
        missing act_input
    :raises RuntimeError: if act_input does not have M * N elements
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
    grad_output_reshaped = grad_output.reshape(batch_dim, n)

    # Copy to contiguous only when neither dimension has unit stride.
    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
        grad_output_reshaped = grad_output_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert (
        grad_output.dtype == weight.dtype
    ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert (
        grad_output_reshaped.shape[1] == weight.shape[0]
    ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != "id":
        assert act_input is not None, f"act_input is required for activation {activation}"

    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
    K, N = weight.shape

    if activation != "id":
        # kernel_bwd reads act_input with grad_input's row stride (stride_cm) and a unit
        # column stride, so it must be a contiguous (M, N) matrix. reshape() raises on a
        # wrong element count instead of letting the kernel read out of bounds.
        act_input = act_input.reshape(M, N).contiguous()

    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_bwd[grid](
        grad_input,
        act_input,
        grad_output_reshaped,
        weight,  # data ptrs
        M,  # shapes
        N,
        K,
        M // 32,  # autotune key: re-tune only when a dimension changes bucket of 32
        N // 32,
        K // 32,
        stride_cm=grad_input.stride(0),  # strides
        # stride_cn=grad_input.stride(1),
        stride_am=grad_output_reshaped.stride(0),
        stride_ak=grad_output_reshaped.stride(1),
        stride_bk=weight.stride(0),
        stride_bn=weight.stride(1),
        ACTIVATION=activation,  # optional fused activation
        GROUP_M=8,  # speed optimization: group the programs
    )

    return grad_input.reshape(*batch_shape, grad_input.shape[-1])