#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
// Needed for atomicLoad/atomicStore with scope/semantics and the nonprivate qualifier below.
#extension GL_KHR_memory_scope_semantics : require
#pragma use_vulkan_memory_model
#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease
// These correspond to X, A, P respectively in the prefix sum paper.
#define FLAG_NOT_READY 0u
#define FLAG_AGGREGATE_READY 1u
#define FLAG_PREFIX_READY 2u
layout(buffer_reference, buffer_reference_align = T_ALIGN) nonprivate buffer StateData {
    DTYPE aggregate;
    DTYPE prefix;
    uint flag;
};
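// `state` (assumed to be declared elsewhere in this file) is a buffer of
// StateData with one entry per partition, indexed by partition number.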
shared DTYPE sh_scratch[WG_SIZE]; // per-thread aggregates, then their workgroup-level scan
shared DTYPE sh_prefix;           // exclusive prefix of this partition, broadcast to all threads
shared uint sh_part_ix;           // partition index, broadcast from thread 0
shared uint sh_flag;              // look-back flag, broadcast from the last thread
void prefix_sum(DataBuffer dst, uint dst_stride, DataBuffer src, uint src_stride)
{
    DTYPE local[N_ROWS];
    // Determine the partition to process. The prefix sum paper (Section 4.4)
    // uses an atomic counter; here the workgroup id is used directly, with the
    // atomic-counter variant left commented out.
    if (gl_LocalInvocationID.x == 0)
        sh_part_ix = gl_WorkGroupID.x;
        // sh_part_ix = atomicAdd(part_counter, 1);
    barrier();
    uint part_ix = sh_part_ix;

    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;
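    // Each thread loads and locally scans N_ROWS consecutive elements; a
    // partition is assumed to span WG_SIZE * N_ROWS == PARTITION_SIZE elements.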
    // TODO: gate buffer read? (evaluate whether shader check or CPU-side padding is better)
    local[0] = src.v[ix * src_stride];
    for (uint i = 1; i < N_ROWS; i++)
        local[i] = local[i - 1] + src.v[(ix + i) * src_stride];
    DTYPE agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;
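    // Hillis-Steele scan of the per-thread aggregates in shared memory.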
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i))
            agg += sh_scratch[gl_LocalInvocationID.x - (1u << i)];
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
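    // sh_scratch[i] now holds the inclusive scan of thread aggregates 0..i;
    // in the last thread, agg is the aggregate of the whole partition.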
    // Publish aggregate for this partition
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        state[part_ix].aggregate = agg;
        if (part_ix == 0)
            state[0].prefix = agg;
    }
    // Write flag with release semantics
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        uint flag = part_ix == 0 ? FLAG_PREFIX_READY : FLAG_AGGREGATE_READY;
        atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
    }
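    // Exclusive prefix of this partition (sum of all earlier partitions),
    // accumulated by the last thread during look-back and broadcast via
    // sh_prefix afterwards.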
    DTYPE exclusive = DTYPE(0);
    if (part_ix != 0) {
        // step 4 of paper: decoupled lookback
        uint look_back_ix = part_ix - 1;
        DTYPE their_agg;
        uint their_ix = 0;
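        // Walk backwards over predecessor partitions until one with a
        // published prefix is found; their_agg/their_ix hold the partial
        // recomputation done while spinning.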
        while (true) {
            // Read flag with acquire semantics.
            if (gl_LocalInvocationID.x == WG_SIZE - 1)
                sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);
            // The flag load is done only in the last thread. However, because the
            // translation of memoryBarrierBuffer to Metal requires uniform control
            // flow, we broadcast it to all threads.
            barrier();
            uint flag = sh_flag;
            barrier();
            if (flag == FLAG_PREFIX_READY) {
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    DTYPE their_prefix = state[look_back_ix].prefix;
                    exclusive = their_prefix + exclusive;
                }
                break;
            } else if (flag == FLAG_AGGREGATE_READY) {
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    their_agg = state[look_back_ix].aggregate;
                    exclusive = their_agg + exclusive;
                }
                look_back_ix--;
                their_ix = 0;
                continue;
            } // else spins
            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                // Unfortunately there's no guarantee of forward progress of other
                // workgroups, so compute a bit of the aggregate before trying again.
                // In the worst case, spinning stops when the aggregate is complete.
                DTYPE m = src.v[(look_back_ix * PARTITION_SIZE + their_ix) * src_stride];
                if (their_ix == 0)
                    their_agg = m;
                else
                    their_agg += m;
                their_ix++;
                if (their_ix == PARTITION_SIZE) {
                    exclusive = their_agg + exclusive;
                    if (look_back_ix == 0) {
                        sh_flag = FLAG_PREFIX_READY;
                    } else {
                        look_back_ix--;
                        their_ix = 0;
                    }
                }
            }
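            // Re-broadcast sh_flag: the last thread may have set it to
            // FLAG_PREFIX_READY above after summing all the way down to
            // partition 0.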
            barrier();
            flag = sh_flag;
            barrier();
            if (flag == FLAG_PREFIX_READY)
                break;
        }
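        // At this point, exclusive in the last thread is the sum of all
        // preceding partitions.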
        // step 5 of paper: compute inclusive prefix
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            DTYPE inclusive_prefix = exclusive + agg;
            sh_prefix = exclusive;
            state[part_ix].prefix = inclusive_prefix;
        }
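        // Publish the prefix with release semantics so later partitions can
        // terminate their look-back.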
        if (gl_LocalInvocationID.x == WG_SIZE - 1)
            atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
    }
    barrier();
    if (part_ix != 0)
        exclusive = sh_prefix;
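    // Each thread's output base is the partition's exclusive prefix plus the
    // scan of the preceding threads' aggregates.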
    DTYPE row = exclusive;
    if (gl_LocalInvocationID.x > 0)
        row += sh_scratch[gl_LocalInvocationID.x - 1];
    // note - may overwrite
    for (uint i = 0; i < N_ROWS; i++)
        dst.v[(ix + i) * dst_stride] = row + local[i];
}