|
#version 450 |
|
|
|
#include "types.comp" |
|
#include "generic_unary_head.comp" |
|
#include "dequant_funcs.comp" |
|
|
|
// Workgroup size selection: IQ4_NL builds run 16 invocations per workgroup,
// every other data format runs a single invocation.
// NOTE(review): presumably the 16 invocations are there for init_iq_shmem(),
// which main() calls before all but invocation 0 return — confirm against the
// helper's definition.
#ifndef DATA_A_IQ4_NL
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#endif
|
|
|
// Dequantizes one QUANT_K-element block of the quantized source and stores the
// resulting floats in data_d. One workgroup handles one block.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    // All invocations call the shared-memory init; only invocation 0 carries
    // on to the copy below.
    init_iq_shmem(gl_WorkGroupSize);
    if (gl_LocalInvocationIndex != 0) {
        return;
    }
#endif

    // Flat index of the first element this workgroup handles.
    // NOTE(review): 512 and 262144 (= 512 * 512) presumably mirror how the
    // host splits the dispatch across y and z — confirm against the caller.
    const uint first_elem = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;

    // Workgroups that start at or beyond the element count have nothing to do.
    if (first_elem >= p.ne) {
        return;
    }

    // Base position in the destination buffer and index of the source block.
    const uint out_base = get_doffset() + dst_idx(first_elem);
    const uint ib = src0_idx_quant(first_elem, QUANT_K);

    const uint a_offset = 0;

    // Per-block pair applied below as value * dm.x + dm.y.
    const vec2 dm = get_dm(ib, a_offset);

    [[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
        const vec4 raw = dequantize4(ib, j / QUANT_R, a_offset);
        const vec4 v = raw * dm.x + dm.y;

#if QUANT_R == 2
        // Components x and z go to consecutive slots in the first half of the
        // output block; y and w go to the matching slots QUANT_K/2 later.
        const uint lo = out_base + j/2;
        const uint hi = lo + QUANT_K/2;
        data_d[lo + 0] = v.x;
        data_d[hi + 0] = v.y;
        data_d[lo + 1] = v.z;
        data_d[hi + 1] = v.w;
#else
        // The four components are stored contiguously.
        const uint o = out_base + j;
        data_d[o + 0] = v.x;
        data_d[o + 1] = v.y;
        data_d[o + 2] = v.z;
        data_d[o + 3] = v.w;
#endif
    }
}
|
|