File size: 2,011 Bytes
57e3690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#version 450

#include "types.comp"

// Workgroup size, which is also the capacity of the shared index array.
// Only columns [0, BLOCK_SIZE) have an invocation, so p.ncols_pad must not
// exceed BLOCK_SIZE (NOTE(review): presumably enforced by the host — confirm).
#define BLOCK_SIZE 1024
// Value of p.order that selects ascending order; any other value sorts descending.
#define ASC 0

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

// Input values, read-only. A_TYPE is presumably defined by types.comp.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
// Output: for each row, the column indices of data_a in sorted order.
layout (binding = 1)          buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;     // real columns per row; also the row stride used for both data_a and data_d
    uint ncols_pad; // padded length sorted by the bitonic network (assumed a power of two >= ncols — TODO confirm on host side)
    uint order;     // ASC => ascending, anything else => descending
} p;

// Per-workgroup permutation being sorted: each entry is a column index into the
// current row. Entries >= p.ncols are padding slots.
shared int dst_row[BLOCK_SIZE];

// Exchanges the entries at positions idx0 and idx1 of the shared index array.
// No synchronization happens here: the caller must guarantee that no other
// invocation accesses either slot before the next barrier().
void swap(uint idx0, uint idx1) {
    const int at_first  = dst_row[idx0];
    const int at_second = dst_row[idx1];
    dst_row[idx0] = at_second;
    dst_row[idx1] = at_first;
}

// Ordering predicate for the argsort: true when the element at column `lhs`
// must come after the element at column `rhs` in the requested order.
// Columns >= p.ncols are padding slots. They always compare as "after" any
// real column, so they collect at the tail for both sort directions.
// data_a is only read when both columns are real.
bool argsort_goes_after(const uint row_offset, const int lhs, const int rhs) {
    if (lhs >= p.ncols) {
        return true;
    }
    if (rhs >= p.ncols) {
        return false;
    }
    return p.order == ASC
        ? data_a[row_offset + lhs] > data_a[row_offset + rhs]
        : data_a[row_offset + lhs] < data_a[row_offset + rhs];
}

// Sorts one row of data_a (selected by the workgroup's y index) with a bitonic
// network over the shared index array. The resulting column permutation is
// written to data_d.
void main() {
    const int col = int(gl_LocalInvocationID.x);
    const uint row = gl_WorkGroupID.y;
    const uint row_offset = row * p.ncols;

    // Start from the identity permutation over the padded row.
    if (col < p.ncols_pad) {
        dst_row[col] = col;
    }
    barrier();

    for (uint seq_len = 2; seq_len <= p.ncols_pad; seq_len *= 2) {
        for (uint stride = seq_len / 2; stride > 0; stride /= 2) {
            const uint partner = col ^ stride;
            // Each pair (col, partner) is handled only by its lower index, so
            // no two invocations touch the same slots within one step.
            if (col < p.ncols_pad && partner > col) {
                // When the seq_len bit of col is clear, the pair is put in the
                // requested order; otherwise it is put in the reverse order.
                const bool forward = (col & seq_len) == 0;
                const int lower_entry = dst_row[col];
                const int upper_entry = dst_row[partner];
                const bool must_swap = forward
                    ? argsort_goes_after(row_offset, lower_entry, upper_entry)
                    : argsort_goes_after(row_offset, upper_entry, lower_entry);
                if (must_swap) {
                    swap(col, partner);
                }
            }
            // barrier() sits outside the per-invocation branch so every
            // invocation reaches it on every step.
            barrier();
        }
    }

    // Padding indices have been pushed to the tail, so the first ncols slots
    // hold the sorted column indices of the real data.
    if (col < p.ncols) {
        data_d[row_offset + col] = dst_row[col];
    }
}