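# Streamlit app: a simple roofline-style cost model comparing Multi-Head
# Attention (MHA) and Multi-Query Attention (MQA) decoding on an NVIDIA A100.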
import streamlit as st
# A100 specs: peak FP16 tensor-core throughput and HBM memory bandwidth
TFLOPS = 312e12  # 312 TFLOPS
GB_S = 1935e9    # 1935 GB/s
# Per-kernel launch overhead in ms (defined but not currently applied below)
THREAD_OVERHEAD = 0.005
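# A kernel's execution time is approximated as compute time (FLOPs at peak
# throughput) plus memory time (bytes moved at peak bandwidth), in ms.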
def calc_exec_time(comp_flop, mem_bytes):
return (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
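# All kernels below model decoding a single new token (sequence dimension 1)
# in fp16, hence the factor of 2 bytes per element. A GEMM of shape
# (m, k) x (k, n) costs 2*m*k*n FLOPs; bytes = inputs + weights + outputs.
# QKV projection (MHA): one d x 3d weight matrix applied to the new token.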
def qkv_mha_exec(bs, h, n, d):
flop = 2*bs*1*d*3*d
nbytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
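# QKV projection (MQA): queries keep h heads (d x d weights), but K and V are
# shared across heads, shrinking their projections to d x (d/h) each; this is
# the (1 + 2/h) factor.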
def qkv_mqa_exec(bs, h, n, d):
flop = 2*bs*1*d*(1+2/h)*d
    nbytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
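# QK gemm (MHA): the new query (d/h per head) is multiplied against the cached
# keys of all n previous positions; each head reads its own K cache.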
def att1_mha_exec(bs, h, n, d):
flop = 2*bs*h*(d/h)*n
nbytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
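# QK gemm (MQA): same FLOPs, but all h heads read a single shared K cache of
# size n x (d/h), cutting K-cache traffic by a factor of h.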
def att1_mqa_exec(bs, h, n, d):
flop = 2*bs*h*(d/h)*n
nbytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
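# Attention-value gemm (MHA): the n attention weights per head are multiplied
# against that head's cached values.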
def att2_mha_exec(bs, h, n, d):
flop = 2*bs*h*n*(d/h)
nbytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
def att2_mqa_exec(bs, h, n, d):
flop = 2*bs*h*n*(d/h)
    nbytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
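# Output projection: a d x d GEMM on the concatenated head outputs
# (identical for MHA and MQA).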
def out_exec(bs, h, n, d):
flop = 2*bs*1*d*d
nbytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
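# Element-wise ops (softmax, layer norm, residuals) are treated as memory
# bound: FLOPs are ignored and cost is one read plus one write of the tensor.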
def softmax_exec(bs, h, n, d):
flop = 0
nbytes = 2*bs*h*n + 2*bs*h*n
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
def ln_exec(bs, h, n, d):
nbytes = 2*bs*1*d + 2*bs*1*d
flop = 0
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
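# MLP linear layer (d -> 4d); the second linear (4d -> d) moves the same
# number of FLOPs and bytes, so this kernel is counted twice in the totals.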
def mlp_exec(bs, h, n, d):
flop = 2*bs*1*d*4*d
nbytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
exec_time = calc_exec_time(flop, nbytes)
return flop, nbytes, exec_time
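# Render a kernel's FLOP count, memory traffic, arithmetic intensity, and
# estimated time into two Streamlit columns; returns the time in ms.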
def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
arith_int = comp_flop/mem_bytes
exec_time = calc_exec_time(comp_flop, mem_bytes)
comp_flop = round(comp_flop/1e9, 2)
mem_bytes = round(mem_bytes/1e6, 2)
c1.write("GFLOP:")
c2.write(str(comp_flop))
c1.write("MB: ")
c2.write(str(mem_bytes))
c1.write("Time (ms):")
c2.write(str(exec_time))
return exec_time
st.sidebar.header("Transformer parameters")
col1, col2 = st.sidebar.columns([2, 4])
bs = st.sidebar.number_input('Batch size', value=10)
h = st.sidebar.number_input('Num heads', value=16)
d = st.sidebar.number_input('Dimension', value=768)
l = st.sidebar.number_input('Num layers', value=24)
n_start = st.sidebar.number_input('Start seq', value=1)
n = st.sidebar.number_input('End seq', value=1024)
st.sidebar.header("GPU parameters")
st.header("Total execution time")
mqa_total_time = 0.
mha_total_time = 0.
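# Total time to generate tokens n_start..n-1: each decoding step i pays the
# per-layer kernel times at the current cache length i, summed over l layers.
# shared_time covers the kernels whose cost is identical for MHA and MQA.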
for i in range(n_start, n):
    shared_time = out_exec(bs, h, i, d)[2] + softmax_exec(bs, h, i, d)[2] \
        + 2*ln_exec(bs, h, i, d)[2] + 2*mlp_exec(bs, h, i, d)[2] + 3*ln_exec(bs, h, i, d)[2]
mha_time = shared_time + qkv_mha_exec(bs, h, i, d)[2] + att1_mha_exec(bs, h, i, d)[2] + att2_mha_exec(bs, h, i, d)[2]
mha_total_time += l*mha_time
mqa_time = shared_time + qkv_mqa_exec(bs, h, i, d)[2] + att1_mqa_exec(bs, h, i, d)[2] + att2_mqa_exec(bs, h, i, d)[2]
mqa_total_time += l*mqa_time
st.write("MHA exec time: " + str(mha_total_time))
st.write("MQA exec time: " + str(mqa_total_time))
st.header('Attention layer')
st.subheader('QKV projection')
st.caption("Multi-Head Attention")
mha_flop = 2*bs*1*d*3*d
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
c1, c2 = st.columns([2, 3])
qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
st.caption("Multi-Query Attention")
mqa_flop = 2*bs*1*d*(1+2/h)*d
mqa_bytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
c1, c2 = st.columns([2, 3])
qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
st.subheader('QK gemm')
st.write("Note that calculation depends on sequence length (n)")
st.caption("Multi-Head Attention")
mha_flop = 2*bs*h*(d/h)*n
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
c1, c2 = st.columns([2, 3])
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
st.caption("Multi-Query Attention")
mqa_flop = 2*bs*h*(d/h)*n
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
c1, c2 = st.columns([2, 3])
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
st.subheader('Attention-value gemm')
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
st.caption("Multi-Head Attention")
mha_flop = 2*bs*h*n*(d/h)
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
c1, c2 = st.columns([2, 3])
att2_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
st.caption("Multi-Query Attention")
mqa_flop = 2*bs*h*n*(d/h)
mqa_bytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
c1, c2 = st.columns([2, 3])
att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
st.subheader('Output projection')
out_flop = 2*bs*1*d*d
out_bytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
c1, c2 = st.columns([2, 3])
out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
st.subheader('Element-wise ops')
st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
st.caption("Softmax")
softmax_bytes = 2*bs*h*n + 2*bs*h*n
c1, c2 = st.columns([2, 3])
softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
st.caption("Layer norm/residual connection")
ln_bytes = 2*bs*1*d + 2*bs*1*d
ln_flop = 0
c1, c2 = st.columns([2, 3])
ln_time = print_kernel_execution(c1, c2, ln_flop, ln_bytes)
st.header('MLP')
st.subheader('First Linear')
mlp1_flop = 2*bs*1*d*4*d
mlp1_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
c1, c2 = st.columns([2, 3])
mlp1_time = print_kernel_execution(c1, c2, mlp1_flop, mlp1_bytes)
st.subheader('Second Linear')
mlp2_flop = 2*bs*1*d*4*d
mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
c1, c2 = st.columns([2, 3])
mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
st.subheader('Element-wise ops')
st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
ln_bytes = 2*bs*1*d + 2*bs*1*d
ln_flop = 0
c1, c2 = st.columns([2, 3])
ln_time = print_kernel_execution(c1, c2, ln_flop, ln_bytes)
st.header("Adding it all up")
shared_time = out_time + softmax_time + 2*ln_time + mlp1_time + mlp2_time + 3*ln_time
mha_total_time = qkv_mha_time + att1_mha_time + att2_mha_time + shared_time
mqa_total_time = qkv_mqa_time + att1_mqa_time + att2_mqa_time + shared_time
st.write("MHA exec time (ms): " + str(mha_total_time))
st.write("MQA exec time (ms): " + str(mqa_total_time)) |