File size: 7,104 Bytes
57e3690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#version 450

#include "mul_mat_vec_base.comp"

// Each workgroup is 32 invocations along x; main() maps one workgroup to one output row.
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

// Workgroup-shared scratch: one partial sum per invocation, reduced into tmp[0] by main().
shared FLOAT_TYPE tmp[32];

// Matrix-vector product for one output row: each workgroup accumulates the dot product of
// one row of the quantized matrix data_a with the vector data_b and stores it to data_d[row].
//
// Each block of data_a carries 4-bit quants (qs), one extra high bit per value (qh, adds 16),
// 6-bit packed scales/mins (scales) and two block-wide factors d.x / d.y.
// NOTE(review): this looks like the Q5_K block layout with QUANT_K == 256 - confirm against
// the definitions in mul_mat_vec_base.comp.
//
// The 32 invocations are split into two groups of 16 (ix = 0, 1). Group ix walks blocks
// ix, ix + 2, ix + 4, ... of the row; within a group each invocation handles 16 values of a
// block. Partial sums go to tmp[] and are combined with a barrier-synchronised tree reduction.
void main() {
    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

    uint a_offset, b_offset, d_offset;
    get_offsets(a_offset, b_offset, d_offset);

    const uint num_blocks_per_row = p.ncols / QUANT_K;
    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;

    const uint lid = gl_LocalInvocationID.x;         // 0...31, unique per invocation
    const uint tid = lid/2;                          // 0...15, position inside a block
    const uint ix  = lid%2;                          // 0, 1, which half of the blocks

    const uint il = tid/4;                           // 0...3
    const uint ir = tid - 4*il;                      // 0...3

    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const uint v_in = il % 2;

    const uint l0 = 4*ir + 2*v_in;                   // even values 0...14
    const uint q_offset = 32*v_im + l0;
    const uint y_offset = 64*v_im + l0;

    // Bit masks selecting this invocation's high bits inside a qh byte:
    // hm1 / (hm1 << 1) for the first 128 values, hm2 / (hm2 << 1) for the second 128.
    const uint8_t hm1 = uint8_t(1 << (2*v_im));
    const uint8_t hm2 = uint8_t(hm1 << 4);

    // Slot owned by this invocation; (ix, tid) -> 16*ix + tid is a bijection onto 0...31.
    const uint tmp_idx = 16 * ix + tid;
    tmp[tmp_idx] = FLOAT_TYPE(0.0); // partial sum for thread in warp

    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
        const uint y1_idx = i * QUANT_K + y_offset;
        const uint y2_idx = y1_idx + 128;

        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);

        // Unpack the 6-bit values: sc0/sc1/sc4/sc5 weight the quant sums (sx, sy, sz, sw),
        // sc2/sc3/sc6/sc7 weight the plain sums of data_b that are scaled by dmin below.
        const uint8_t sc0 = uint8_t(  data_a[ib0 + i].scales[v_im * 2    ]       & 0x3f);
        const uint8_t sc1 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
        const uint8_t sc2 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
        const uint8_t sc3 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
        const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
        const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
        const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
        const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));

        // Low 4 bits of each quant: low nibble and high nibble of the same qs byte belong to
        // values 32 positions apart in data_b (see the y1_idx / y1_idx + 32 pairing below).
        const uint8_t q4_0  = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
        const uint8_t q4_1  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
        const uint8_t q4_2  = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
        const uint8_t q4_3  = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
        const uint8_t q4_4  = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
        const uint8_t q4_5  = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
        const uint8_t q4_6  = uint8_t(data_a[ib0 + i].qs[q_offset + 16]  >> 4);
        const uint8_t q4_7  = uint8_t(data_a[ib0 + i].qs[q_offset + 17]  >> 4);
        const uint8_t q4_8  = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
        const uint8_t q4_9  = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
        const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
        const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
        const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
        const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
        const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80]  >> 4);
        const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81]  >> 4);

        // Four 4-element dot products of data_b with the 5-bit quants (nibble + 16 if the
        // matching qh bit is set), one per scale.
        const FLOAT_TYPE sx =
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx +  1]), (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
             FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
        const FLOAT_TYPE sy =
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
             FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
        const FLOAT_TYPE sz =
          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y2_idx +  1]), (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
             FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
        const FLOAT_TYPE sw =
          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
             FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
        // Sum of the same data_b elements weighted by sc2/sc3/sc6/sc7; subtracted with dmin.
        const FLOAT_TYPE smin =
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
              (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
        // tmp += dall * (sx*sc0 + sy*sc1 + sz*sc4 + sw*sc5) - dmin * smin
        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
    }

    // sum up partial sums and write back result
    // The reduction is indexed by the unique invocation id, not by tid: tid is shared by the
    // ix = 0 and ix = 1 invocations, so indexing by tid would make two invocations
    // read-modify-write the same tmp element between barriers.
    barrier();
    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
        if (lid < s) {
            tmp[lid] += tmp[lid + s];
        }
        barrier();
    }
    // Exactly one invocation stores the row result.
    if (lid == 0) {
        data_d[d_offset + row] = D_TYPE(tmp[0]);
    }
}