File size: 2,028 Bytes
e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/******************************************************************************

 * Copyright (c) 2023, Tri Dao.

 ******************************************************************************/

#include <torch/python.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,

                       const torch::Tensor cos, const torch::Tensor sin,

                       torch::Tensor out1, torch::Tensor out2,

                       const bool conj) {
    auto iter = at::TensorIteratorConfig()
        .add_output(out1)
        .add_output(out2)
        .add_input(x1)
        .add_input(x2)
        .add_input(cos)
        .add_input(sin)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    if (!conj) {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                    scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
                scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
                return {out1, out2};
            });
        });
    } else {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                    scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
                scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
                return {out1, out2};
            });
        });
    }
}