File size: 5,750 Bytes
b664585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#version 450

// 16-bit float arithmetic is only required when the shader is built with FLOAT16 defined.
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#extension GL_EXT_shader_explicit_arithmetic_types : require

#include "mul_mat_vec_base.comp"

// The workgroup x-size is taken from specialization constant id 0, which is the
// same id that BLOCK_SIZE below is bound to.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// BLOCK_SIZE: second dimension of tmpsh and starting point (BLOCK_SIZE/2) of the
// reduction in compute_outputs.
// NUM_ROWS: number of output rows a workgroup handles per compute_outputs call
// (main derives first_row from it).
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;

// K_PER_ITER is the number of columns one invocation consumes per iter() call:
// 8 for every A type other than F32/F16, 2 for F32/F16.
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#endif


// Buffer offsets shared between compute_outputs (which sets them) and iter
// (which reads a_offset, b_offset and y_offset). y_offset is the distance in B
// between the two halves of an element pair; see its assignment in compute_outputs.
uint a_offset, b_offset, d_offset, y_offset;

// Per-row, per-invocation partial sums; reduced across the workgroup in compute_outputs.
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];

// Accumulates one K_PER_ITER-wide slice of the dot product into temp[] for each
// of the num_rows rows starting at first_row.
//
//   temp      - per-row running sums, updated in place
//   first_row - first row of A handled by this workgroup
//   num_rows  - number of rows to process (the loop reads temp[0..num_rows-1])
//   tid       - invocation index within the workgroup
//   i         - iteration index, already scaled by K_PER_ITER by the caller
//   lastiter  - when true (K_PER_ITER == 2 path only) the second element of the
//               pair is bounds-checked against p.ncols before B is read
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
    // Column owned by this invocation, split into its position inside the
    // quant block and the column at which that block starts.
    const uint k         = i*BLOCK_SIZE + K_PER_ITER*tid;
    const uint in_block  = k%QUANT_K;
    const uint q_idx     = in_block/QUANT_R; // quant index
    const uint blk_start = k - in_block;     // y block start index

    // Index into B of the first element this invocation consumes.
    const uint b_idx = b_offset + blk_start + q_idx;

#if K_PER_ITER == 8
#if QUANT_R == 2
    // Two vec4 loads, y_offset apart; their components are consumed
    // alternately (lo.x, hi.x, lo.y, hi.y, ...) by the fma chain below.
    const B_TYPE_VEC4 b_lo = data_b_v4[b_idx / 4];
    const B_TYPE_VEC4 b_hi = data_b_v4[(b_idx + y_offset) / 4];
    const FLOAT_TYPE y0 = FLOAT_TYPE(b_lo.x);
    const FLOAT_TYPE y1 = FLOAT_TYPE(b_hi.x);
    const FLOAT_TYPE y2 = FLOAT_TYPE(b_lo.y);
    const FLOAT_TYPE y3 = FLOAT_TYPE(b_hi.y);
    const FLOAT_TYPE y4 = FLOAT_TYPE(b_lo.z);
    const FLOAT_TYPE y5 = FLOAT_TYPE(b_hi.z);
    const FLOAT_TYPE y6 = FLOAT_TYPE(b_lo.w);
    const FLOAT_TYPE y7 = FLOAT_TYPE(b_hi.w);
#else
    // Two consecutive vec4 loads, consumed in order.
    const B_TYPE_VEC4 b_lo = data_b_v4[b_idx / 4];
    const B_TYPE_VEC4 b_hi = data_b_v4[b_idx / 4 + 1];
    const FLOAT_TYPE y0 = FLOAT_TYPE(b_lo.x);
    const FLOAT_TYPE y1 = FLOAT_TYPE(b_lo.y);
    const FLOAT_TYPE y2 = FLOAT_TYPE(b_lo.z);
    const FLOAT_TYPE y3 = FLOAT_TYPE(b_lo.w);
    const FLOAT_TYPE y4 = FLOAT_TYPE(b_hi.x);
    const FLOAT_TYPE y5 = FLOAT_TYPE(b_hi.y);
    const FLOAT_TYPE y6 = FLOAT_TYPE(b_hi.z);
    const FLOAT_TYPE y7 = FLOAT_TYPE(b_hi.w);
#endif
#else
    // On the last iteration the second element of the pair may lie past the
    // end of the row. In that case neither its B value is read nor its product
    // accumulated. A is still read as a pair; for quantized formats both
    // values come from the same block. For F16/F32 the second A read should
    // arguably be skipped too, but currently is not.
    const bool OOB = lastiter && (blk_start + q_idx + y_offset >= p.ncols);

    const FLOAT_TYPE y0 = FLOAT_TYPE(data_b[b_idx]);
    FLOAT_TYPE y1 = FLOAT_TYPE(0);
    if (!OOB) {
        y1 = FLOAT_TYPE(data_b[b_idx + y_offset]);
    }
#endif
    [[unroll]] for (uint r = 0; r < num_rows; ++r) {
        const uint blk = ((first_row + r)*p.ncols + k)/QUANT_K; // block index

        // Accumulate in a local; the fma order matches the element order above.
        FLOAT_TYPE acc = temp[r];

#if K_PER_ITER == 8
        const vec4 a_lo = dequantize4(blk, q_idx, a_offset);
        const vec4 a_hi = dequantize4(blk, q_idx+(4/QUANT_R), a_offset);

        acc = fma(FLOAT_TYPE(a_lo.x), y0, acc);
        acc = fma(FLOAT_TYPE(a_lo.y), y1, acc);
        acc = fma(FLOAT_TYPE(a_lo.z), y2, acc);
        acc = fma(FLOAT_TYPE(a_lo.w), y3, acc);
        acc = fma(FLOAT_TYPE(a_hi.x), y4, acc);
        acc = fma(FLOAT_TYPE(a_hi.y), y5, acc);
        acc = fma(FLOAT_TYPE(a_hi.z), y6, acc);
        acc = fma(FLOAT_TYPE(a_hi.w), y7, acc);
#else
        const vec2 a_pair = dequantize(blk, q_idx, a_offset);

        acc = fma(FLOAT_TYPE(a_pair.x), y0, acc);
        if (!OOB) {
            acc = fma(FLOAT_TYPE(a_pair.y), y1, acc);
        }
#endif

        temp[r] = acc;
    }
}

// Computes the outputs for rows [first_row, first_row + num_rows) and writes
// them to data_d.
//
// Each invocation accumulates partial sums over its share of the columns by
// calling iter(); the partial sums are then reduced across the workgroup
// through the shared array tmpsh, and invocation 0 writes the results.
//
//   first_row - first row handled by this workgroup
//   num_rows  - number of rows to compute; the loops index temp/tmpsh with it,
//               so it must not exceed NUM_ROWS
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    get_offsets(a_offset, b_offset, d_offset);
    // a_offset is used by iter() alongside block indices, so convert it to block units.
    a_offset /= QUANT_K;

    // Distance in B between the two elements of a pair: adjacent when
    // QUANT_R == 1, half a quant block apart otherwise.
    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

    FLOAT_TYPE temp[NUM_ROWS];

    for (uint i = 0; i < NUM_ROWS; ++i) {
        temp[i] = FLOAT_TYPE(0);
    }

    // Number of full workgroup-wide passes over the columns, plus one more for
    // this invocation if its slice of the remainder is still in range.
    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters++;
    }
    int unroll_count = 4;
    uint unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
    // With an odd column count, the final iteration must run with
    // lastiter == true so iter() bounds-checks the second element of the pair.
    // If the unrolled loop would consume every iteration, hold one group back
    // for the loops below.
    if ((p.ncols & 1) != 0 &&
        unrolled_iters == num_iters &&
        unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
#endif

    uint i = 0;
    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
            i++;
        }
    }
    unroll_count = 2;
    unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
    // Same reasoning as above for the 2x unrolled loop.
    if ((p.ncols & 1) != 0 &&
        unrolled_iters == num_iters &&
        unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
#endif

    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
            i++;
        }
    }
    // Remaining iterations run with the bounds check enabled.
    while (i < num_iters) {
        iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
        i++;
    }

    // sum up partial sums and write back result
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        tmpsh[n][tid] = temp[n];
    }
    barrier();
    // Halving reduction: on each step the lower s invocations fold in the
    // upper s entries. The barrier is outside the tid test so every
    // invocation reaches it.
    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
        if (tid < s) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                tmpsh[n][tid] += tmpsh[n][tid + s];
            }
        }
        barrier();
    }
    if (tid == 0) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
        }
    }
}

// Entry point: selects the rows this workgroup is responsible for and computes
// them, clamping the row count for the final partial group.
void main() {
    // First row for this workgroup, derived from its x and z workgroup ids.
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

#if defined(DATA_A_IQ4_NL)
    init_iq4nl_shmem();
#endif

    // p.stride_d is used here as the upper bound on row indices.
    const bool full_group = first_row + NUM_ROWS <= p.stride_d;
    if (full_group) {
        // Common case: all NUM_ROWS rows are in range.
        compute_outputs(first_row, NUM_ROWS);
        return;
    }

    // Partial group: compute only the rows that remain, if any.
    if (first_row < p.stride_d) {
        compute_outputs(first_row, p.stride_d - first_row);
    }
}