/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#include <torch/python.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <thrust/tuple.h>  // thrust::tuple, returned by the GPU lambdas below

// Applies a rotary embedding to the tensor pair (x1, x2): each element pair
// is rotated by the angle whose cosine/sine are given in `cos`/`sin`, and the
// result is written to (out1, out2). When `conj` is true, the conjugate
// (inverse) rotation is applied instead, as needed for the backward pass.
void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
                       const torch::Tensor cos, const torch::Tensor sin,
                       torch::Tensor out1, torch::Tensor out2,
                       const bool conj) {
    // Build a TensorIterator over the four inputs and two outputs so the
    // element-wise kernels below get broadcasting and layout handling for
    // free. Dtype checks and promotion are disabled so that cos/sin may
    // differ in dtype from x1/x2.
    auto iter = at::TensorIteratorConfig()
        .add_output(out1)
        .add_output(out2)
        .add_input(x1)
        .add_input(x2)
        .add_input(cos)
        .add_input(sin)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();

    if (!conj) {
        // Forward rotation: out1 = x1*cos - x2*sin, out2 = x1*sin + x2*cos.
        // Arithmetic is carried out in float and only the results are cast
        // back to scalar_t, so half/bfloat16 inputs do not lose precision in
        // the intermediate products.
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                     scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                    scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
                    scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
                    return {out1, out2};
                });
        });
    } else {
        // Conjugate (inverse) rotation, i.e. rotation by the negated angle:
        // out1 = x1*cos + x2*sin, out2 = -x1*sin + x2*cos.
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                     scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                    scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
                    scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
                    return {out1, out2};
                });
        });
    }
}
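
// A minimal binding sketch showing how this kernel could be exposed to
// Python, assuming the file is compiled with torch.utils.cpp_extension
// (which defines TORCH_EXTENSION_NAME and pulls in pybind11 through
// <torch/python.h>). The exported name "apply_rotary" and the docstring are
// illustrative assumptions, not part of the kernel above; the outputs are
// written in place into out1/out2, so the Python caller allocates them.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("apply_rotary", &apply_rotary_cuda,
          "Apply rotary embedding to (x1, x2), writing into (out1, out2)");
}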