|
#include "kernel_operator.h" |
|
|
|
|
|
using namespace AscendC; |
|
|
|
// Number of buffers allocated per local TQue; used both as the TQue depth
// template argument and as the buffer count passed to TPipe::InitBuffer.
#define BUFFER_NUM 2

// Number of int8 quants that share one half-precision scale
// (the Q8_0 group size used for indexing quants and scales below).
#define QK8_0 32
|
|
|
class GET_ROW_Q8_0 { |
|
public: |
|
__aicore__ inline GET_ROW_Q8_0() {} |
|
__aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, |
|
int64_t *input_ne_ub, int64_t *indices_ne_ub, |
|
size_t *indices_nb_ub, int64_t *output_ne_ub, |
|
size_t *output_nb_ub) { |
|
int64_t op_block_num = GetBlockNum(); |
|
int64_t op_block_idx = GetBlockIdx(); |
|
|
|
for (int i = 0; i < 4; i++) { |
|
input_ne[i] = input_ne_ub[i]; |
|
indices_ne[i] = indices_ne_ub[i]; |
|
indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; |
|
scale_ne[i] = input_ne_ub[i]; |
|
output_ne[i] = output_ne_ub[i]; |
|
output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; |
|
} |
|
|
|
|
|
scale_ne[0] /= QK8_0; |
|
|
|
input_stride[0] = 1; |
|
scale_stride[0] = 1; |
|
output_stride[0] = 1; |
|
for (int i = 1; i < 4; i++) { |
|
input_stride[i] = input_stride[i - 1] * input_ne[i - 1]; |
|
scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; |
|
} |
|
|
|
group_size_in_row = input_ne[0] / QK8_0; |
|
int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] * |
|
input_ne[3] * sizeof(int8_t); |
|
|
|
|
|
|
|
uint64_t n_elements = |
|
indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; |
|
dr = n_elements / op_block_num; |
|
|
|
uint64_t tails = n_elements % op_block_num; |
|
if (op_block_idx < tails) { |
|
dr += 1; |
|
ir = dr * op_block_idx; |
|
} else { |
|
ir = dr * op_block_idx + tails; |
|
} |
|
|
|
input_gm.SetGlobalBuffer((__gm__ int8_t *)input); |
|
scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset)); |
|
indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); |
|
output_gm.SetGlobalBuffer((__gm__ float *)output); |
|
|
|
pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t)); |
|
pipe.InitBuffer(cast_queue, BUFFER_NUM, QK8_0 * sizeof(half)); |
|
pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(float)); |
|
} |
|
|
|
__aicore__ inline void copy_in(uint32_t offset) { |
|
LocalTensor<int8_t> input_local = input_queue.AllocTensor<int8_t>(); |
|
DataCopy(input_local, input_gm[offset], QK8_0); |
|
input_queue.EnQue(input_local); |
|
} |
|
|
|
__aicore__ inline void copy_out(uint32_t offset) { |
|
LocalTensor<float> output_local = output_queue.DeQue<float>(); |
|
DataCopy(output_gm[offset], output_local, QK8_0); |
|
output_queue.FreeTensor(output_local); |
|
} |
|
|
|
__aicore__ inline void calculate_group(int64_t idx, int64_t group) { |
|
const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); |
|
const int64_t indices_ne1_idx = |
|
(idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / |
|
indices_ne[0]; |
|
const int64_t indices_ne0_idx = |
|
(idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - |
|
indices_ne1_idx * indices_ne[0]); |
|
|
|
const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + |
|
indices_ne1_idx * indices_stride[1] + |
|
indices_ne2_idx * indices_stride[2]; |
|
const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); |
|
|
|
const int64_t input_offset = selected_row_idx * input_stride[1] + |
|
indices_ne1_idx * input_stride[2] + |
|
indices_ne2_idx * input_stride[3] + |
|
group * QK8_0; |
|
const int64_t scale_offset = selected_row_idx * scale_stride[1] + |
|
indices_ne1_idx * scale_stride[2] + |
|
indices_ne2_idx * scale_stride[3] + group; |
|
const int64_t output_offset = indices_ne0_idx * output_stride[1] + |
|
indices_ne1_idx * output_stride[2] + |
|
indices_ne2_idx * output_stride[3] + |
|
group * QK8_0; |
|
|
|
copy_in(input_offset); |
|
LocalTensor<int8_t> input_local = input_queue.DeQue<int8_t>(); |
|
LocalTensor<half> cast_local = cast_queue.AllocTensor<half>(); |
|
LocalTensor<float> output_local = output_queue.AllocTensor<float>(); |
|
|
|
|
|
Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0); |
|
Cast(output_local, cast_local, RoundMode::CAST_NONE, QK8_0); |
|
|
|
|
|
half scale = scale_gm.GetValue(scale_offset); |
|
Muls(output_local, output_local, (float)scale, QK8_0); |
|
|
|
input_queue.FreeTensor(input_local); |
|
cast_queue.FreeTensor(cast_local); |
|
output_queue.EnQue(output_local); |
|
|
|
copy_out(output_offset); |
|
} |
|
|
|
__aicore__ inline void calculate() { |
|
for (int64_t i = ir; i < ir + dr; i++) { |
|
for (int64_t j = 0; j < group_size_in_row; j++) { |
|
calculate_group(i, j); |
|
} |
|
} |
|
} |
|
|
|
private: |
|
int64_t input_ne[4]; |
|
size_t input_stride[4]; |
|
|
|
int64_t scale_ne[4]; |
|
size_t scale_stride[4]; |
|
|
|
int64_t indices_ne[4]; |
|
size_t indices_stride[4]; |
|
|
|
int64_t output_ne[4]; |
|
size_t output_stride[4]; |
|
|
|
int64_t ir; |
|
int64_t dr; |
|
|
|
int64_t group_size_in_row; |
|
|
|
TPipe pipe; |
|
GlobalTensor<int8_t> input_gm; |
|
GlobalTensor<half> scale_gm; |
|
GlobalTensor<int32_t> indices_gm; |
|
GlobalTensor<float> output_gm; |
|
TQue<QuePosition::VECIN, BUFFER_NUM> input_queue; |
|
TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue; |
|
TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue; |
|
}; |
|
|
|
template <typename T> |
|
__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { |
|
auto gm_ptr = (__gm__ uint8_t *)gm; |
|
auto ub_ptr = (uint8_t *)(ub); |
|
for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { |
|
*ub_ptr = *gm_ptr; |
|
} |
|
} |
|
|
|
/// Kernel entry point: copies the 4-dimensional shape/stride metadata from
/// global memory into local arrays, then runs the Q8_0 get-rows operator.
///
/// @param input_gm      quantized tensor (int8 quants followed by half scales)
/// @param indices_gm    int32 row indices
/// @param output_gm     float output
/// @param input_ne_gm   4 x int64 element counts of the input
/// @param indices_ne_gm 4 x int64 element counts of the indices
/// @param indices_nb_gm 4 x size_t strides of the indices
/// @param output_ne_gm  4 x int64 element counts of the output
/// @param output_nb_gm  4 x size_t strides of the output
extern "C" __global__ __aicore__ void ascendc_get_row_q8_0(
    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
    GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
    GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
    int64_t input_ne_ub[4];
    int64_t indices_ne_ub[4];
    size_t indices_nb_ub[4];
    int64_t output_ne_ub[4];
    size_t output_nb_ub[4];

    // Copy exactly the size of each destination array instead of a
    // hard-coded 32, which silently assumed 8-byte elements and would overrun
    // the array otherwise. NOTE(review): the host is presumed to write the
    // same element widths — confirm against the launcher.
    copy_to_ub(input_ne_gm, input_ne_ub, sizeof(input_ne_ub));
    copy_to_ub(indices_ne_gm, indices_ne_ub, sizeof(indices_ne_ub));
    copy_to_ub(indices_nb_gm, indices_nb_ub, sizeof(indices_nb_ub));
    copy_to_ub(output_ne_gm, output_ne_ub, sizeof(output_ne_ub));
    copy_to_ub(output_nb_gm, output_nb_ub, sizeof(output_nb_ub));

    GET_ROW_Q8_0 op;
    op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
            indices_nb_ub, output_ne_ub, output_nb_ub);
    op.calculate();
}
|
|