#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
// Needed for atomicLoad/atomicStore with explicit scope and semantics
// arguments (Vulkan memory model).
#extension GL_KHR_memory_scope_semantics : require

// Shorthand for the storage-class/semantics argument pairs passed to the
// flag atomics below.
#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease

// Status of a partition's entry in the look-back state buffer.
#define FLAG_NOT_READY 0u
#define FLAG_AGGREGATE_READY 1u
#define FLAG_PREFIX_READY 2u
// Per-partition state for the decoupled look-back. The nonprivate qualifier
// makes the plain loads/stores of aggregate and prefix participate in the
// memory model, so they are ordered by the acquire/release atomics on flag.
layout(buffer_reference, buffer_reference_align = T_ALIGN) nonprivate buffer StateData {
    DTYPE aggregate;
    DTYPE prefix;
    uint flag;
};

shared DTYPE sh_scratch[WG_SIZE];
shared DTYPE sh_prefix;
shared uint sh_part_ix;
shared uint sh_flag;
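
// Single-pass prefix sum with decoupled look-back (Merrill & Garland). Each
// workgroup scans one PARTITION_SIZE-element partition, publishes its
// aggregate, then resolves its exclusive prefix by walking back over earlier
// partitions' published state.
//
// Note: DataBuffer (a buffer reference exposing `DTYPE v[]`) and `state`
// (a StateData reference indexed per partition) are assumed to be declared
// elsewhere in this file; all flag words must start out as FLAG_NOT_READY.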
void prefix_sum(DataBuffer dst, uint dst_stride, DataBuffer src, uint src_stride)
{
    DTYPE local[N_ROWS];

    // Broadcast the partition index through shared memory. This must run in
    // the first invocation of *every* workgroup, so test the local invocation
    // ID; gl_GlobalInvocationID.x == 0 would hold only in workgroup 0 and
    // leave sh_part_ix uninitialized everywhere else.
    if (gl_LocalInvocationID.x == 0)
        sh_part_ix = gl_WorkGroupID.x;
    barrier();
    uint part_ix = sh_part_ix;
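    // Note: taking part_ix straight from gl_WorkGroupID.x assumes workgroups
    // make progress roughly in launch order; implementations of this scheme
    // commonly assign part_ix from a global atomicAdd counter instead, which
    // is why the shared-memory broadcast shape is kept here. Either way, the
    // redundant-recompute fallback below keeps the look-back deadlock-free.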

    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;

    // Each invocation loads N_ROWS strided elements and scans them serially;
    // local[i] holds the inclusive sum of rows 0..i, so local[N_ROWS - 1] is
    // this invocation's total.
    local[0] = src.v[ix * src_stride];
    for (uint i = 1; i < N_ROWS; i++)
        local[i] = local[i - 1] + src.v[(ix + i) * src_stride];

    DTYPE agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    // Hillis-Steele scan of the per-invocation totals in shared memory;
    // after LG_WG_SIZE rounds, sh_scratch holds their inclusive prefix sums.
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i))
            agg += sh_scratch[gl_LocalInvocationID.x - (1u << i)];
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
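    // E.g. with WG_SIZE = 4 and per-invocation totals [a, b, c, d], the two
    // rounds produce [a, a+b, b+c, c+d] then [a, a+b, a+b+c, a+b+c+d]; the
    // last invocation's agg is the partition-wide total.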

    // Publish this partition's aggregate (and, for partition 0, its final
    // prefix), then set the status flag with release semantics so those
    // nonprivate stores become visible to any matching acquire of the flag.
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        state[part_ix].aggregate = agg;
        if (part_ix == 0)
            state[0].prefix = agg;
        uint flag = part_ix == 0 ? FLAG_PREFIX_READY : FLAG_AGGREGATE_READY;
        atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
    }
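
    // Decoupled look-back: starting at the immediately preceding partition,
    // fold published aggregates into `exclusive` until some partition's full
    // prefix is found, at which point `exclusive` is this partition's
    // exclusive prefix. Only the last invocation touches global state; loop
    // decisions are broadcast to the rest of the workgroup via sh_flag.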
    DTYPE exclusive = DTYPE(0);
    if (part_ix != 0) {
        uint look_back_ix = part_ix - 1;

        DTYPE their_agg;
        uint their_ix = 0;
        while (true) {
            // Acquire-load the predecessor's flag in the last invocation and
            // broadcast it through shared memory.
            if (gl_LocalInvocationID.x == WG_SIZE - 1)
                sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);
            barrier();
            uint flag = sh_flag;
            // Second barrier: sh_flag is rewritten in the fallback path
            // below, so all reads of it must complete first.
            barrier();

            if (flag == FLAG_PREFIX_READY) {
                // The predecessor's inclusive prefix is available; fold it in
                // and stop looking back.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    DTYPE their_prefix = state[look_back_ix].prefix;
                    exclusive = their_prefix + exclusive;
                }
                break;
            } else if (flag == FLAG_AGGREGATE_READY) {
                // Only its aggregate is available; fold that in and keep
                // walking back.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    their_agg = state[look_back_ix].aggregate;
                    exclusive = their_agg + exclusive;
                }
                look_back_ix--;
                their_ix = 0;
                continue;
            }
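
            // Neither value is published yet. Since there is no guarantee
            // that the other workgroup is making progress, recompute one
            // element of its aggregate directly from the input instead of
            // spinning; in the worst case this degrades to O(n log n) work,
            // but it can never deadlock.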
            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                DTYPE m = src.v[(look_back_ix * PARTITION_SIZE + their_ix) * src_stride];
                if (their_ix == 0)
                    their_agg = m;
                else
                    their_agg += m;
                their_ix++;
                if (their_ix == PARTITION_SIZE) {
                    // Recomputed this partition's whole aggregate.
                    exclusive = their_agg + exclusive;
                    if (look_back_ix == 0) {
                        // Reached partition 0: exclusive is now complete.
                        sh_flag = FLAG_PREFIX_READY;
                    } else {
                        look_back_ix--;
                        their_ix = 0;
                    }
                }
            }
            barrier();
            flag = sh_flag;
            barrier();
            if (flag == FLAG_PREFIX_READY)
                break;
        }

        // Publish this partition's inclusive prefix, stash the exclusive
        // prefix for the rest of the workgroup, and release it via the flag
        // so later partitions can stop their look-back here.
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            DTYPE inclusive_prefix = exclusive + agg;
            sh_prefix = exclusive;
            state[part_ix].prefix = inclusive_prefix;
            atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
        }
    }

    // Broadcast the exclusive prefix, which so far lives only in the last
    // invocation, to the whole workgroup.
    barrier();
    if (part_ix != 0)
        exclusive = sh_prefix;

    // Each invocation's base is the partition's exclusive prefix plus the
    // inclusive scan of all preceding invocations' totals.
    DTYPE row = exclusive;
    if (gl_LocalInvocationID.x > 0)
        row += sh_scratch[gl_LocalInvocationID.x - 1];

    // Write the inclusive prefix sum for this invocation's N_ROWS elements.
    for (uint i = 0; i < N_ROWS; i++)
        dst.v[(ix + i) * dst_stride] = row + local[i];
}
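
// A minimal entry point might look like the following (a sketch only; the
// real file's bindings, push constants, and main are assumed to live
// elsewhere, and the Params block here is hypothetical):
//
// layout(local_size_x = WG_SIZE, local_size_y = 1, local_size_z = 1) in;
//
// layout(push_constant) uniform Params {
//     DataBuffer dst;
//     DataBuffer src;
// } params;
//
// void main() {
//     prefix_sum(params.dst, 1, params.src, 1);
// }
//
// Dispatch one workgroup per PARTITION_SIZE input elements, with the state
// buffer's flag words cleared to FLAG_NOT_READY (0) before the dispatch.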