Spaces:
Runtime error
Runtime error
// Workgroup size: x and y are taken from specialization constants 0 and 1,
// z is fixed at 1.
layout(local_size_x_id = 0) in;
layout(local_size_y_id = 1) in;
layout(local_size_z = 1) in;

// Storage buffers used by main():
//   inA  - read-only, addressed byte by byte (quantized blocks of SIZE_OF_BLOCK bytes)
//   inB  - read-only float input
//   out_ - write-only float output, one element written per subgroup
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
// Push constants supplied by the host. Member order and types define the
// push-constant layout and must match the host-side structure.
layout (push_constant) uniform parameter {
    uint inAOff; // added to every inA index (inA is indexed in bytes)
    uint inBOff; // added to every inB index (float elements)
    uint outOff; // added to the out_ index (float elements)
    int ne00;    // elements per inA row; ne00/QK_K is the number of blocks per row
    int ne10;    // inB stride between consecutive gl_WorkGroupID.y values
    int ne0;     // out_ stride between consecutive gl_WorkGroupID.y values
    int ne1;     // multiplied into the per-gl_WorkGroupID.z strides of inB and out_
    int ne01;    // not read by the main() visible in this file
    int gqa;     // gl_WorkGroupID.z is divided by this to pick the inA batch
} pcs;
// Dot product between one quantized row of inA and one float row of inB.
//
// Each subgroup owns a single output element. Its invocations split the blocks
// of the row (and the positions inside each block) between them, accumulate a
// partial sum each, and the subgroup-wide total is stored to out_ by one
// elected invocation.
//
// NOTE(review): the byte layout read below (4-bit low parts, 2-bit high parts,
// int8 scales, then a value decoded by u8BufToFloat16) looks like a 6-bit
// K-quant block -- confirm against the definition of SIZE_OF_BLOCK.
void main() {
    // Each mask isolates one 2-bit field of a high-bits byte.
    const uint8_t kmask1 = uint8_t(0x03);
    const uint8_t kmask2 = uint8_t(0x0C);
    const uint8_t kmask3 = uint8_t(0x30);
    const uint8_t kmask4 = uint8_t(0xC0);

    // Byte offsets of the sections inside one block, as they are indexed below.
    const uint qh_base     = QK_K/2;                   // high-bit bytes start after QK_K/2 low-nibble bytes
    const uint scales_base = QK_K/2 + QK_K/4;          // int8 scales start after QK_K/4 high-bit bytes
    const uint d_base      = QK_K/2 + QK_K/4 + QK_K/16; // block scale starts after QK_K/16 scale bytes

    const uint nb = pcs.ne00/QK_K; // number of blocks per row

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint r2 = gl_WorkGroupID.z;

    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID); // one row per subgroup
    const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
    const uint x = row * nb + offset0; // Based from inA without base offset
    const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB

    float sumf = 0;

    // bits of invocation ID for gl_SubgroupSize=32:
    //  x   x   x   x   x
    //  4   3   2   1   0
    // (     tid     ) ix
    //  ip (   il    )

    // NOTE: requires gl_SubgroupSize >= 16. For smaller subgroups block_stride
    // is 0, which makes the division/modulo below undefined and would keep the
    // block loop from ever advancing.
    const uint block_stride = gl_SubgroupSize / 16;         // number of blocks processed per loop step by a subgroup
    const uint tid = gl_SubgroupInvocationID/block_stride;  // first block_stride groups have tid=0
    const uint ix  = gl_SubgroupInvocationID%block_stride;  // first block is 0..block_stride-1
    const uint ip  = tid/8;         // first or second half of block (0 or 1)
    const uint il  = tid%8;         // each half has 8 parts, one per scale
    const uint n   = 4;             // 4 scales at a time (and 4 sums)
    const uint l0  = n*il;          // offset into half-block, 0..28
    const uint is  = 8*ip + l0/16;  // 0, 1, 8, 9

    const uint y_offset   = 128*ip + l0;
    const uint q_offset_l = 64*ip + l0;
    const uint q_offset_h = 32*ip + l0;

    // Per-invocation offsets inside a block; identical for every block, so
    // they are computed once outside the loop.
    const uint qlIndex = q_offset_l;          // first run of low-nibble bytes
    const uint q2Index = qlIndex + QK_K/8;    // second run, QK_K/8 bytes further
    const uint qhIndex = q_offset_h;          // high-bit bytes

    for (uint i = ix; i < nb; i += block_stride) {
        const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; // first byte of block i in inA
        const uint y = yy + i * QK_K + y_offset;                     // matching position in inB

        float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (uint l = 0; l < n; ++l) {
            const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
            const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
            const uint8_t currentQh = inA[baseIndex + qh_base + qhIndex + l];

            // Rebuild each 6-bit value from a 4-bit low part and a 2-bit high
            // part moved into bits 4..5, then subtract 32 to centre it on zero.
            sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
            sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
            sums[2] += inB[y+l+64] * (int8_t((currentQ1 >>  4) | ((currentQh & kmask3) << 0)) - 32);
            sums[3] += inB[y+l+96] * (int8_t((currentQ2 >>  4) | ((currentQh & kmask4) >> 2)) - 32);
        }

        // Weight the four partial sums by their int8 scales (two bytes apart)
        // and by the per-block factor decoded by u8BufToFloat16.
        const float d = u8BufToFloat16(inA, baseIndex + d_base);
        const uint sc = baseIndex + scales_base + is;
        sumf += d * (sums[0] * int8_t(inA[sc])
                   + sums[1] * int8_t(inA[sc + 2])
                   + sums[2] * int8_t(inA[sc + 4])
                   + sums[3] * int8_t(inA[sc + 6]));
    }

    // Reduce the partial sums across the subgroup; one invocation stores the result.
    const float tot = subgroupAdd(sumf);
    if (subgroupElect()) {
        out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
    }
}