Spaces:
Runtime error
Runtime error
File size: 1,850 Bytes
44ac4da 1adbed4 c904168 1934207 f98bc09 d9d6497 1adbed4 f98bc09 f86ffab df3088f f86ffab df3088f c904168 2e71f02 c904168 f86ffab c904168 df3088f d9d6497 c904168 f86ffab 23b5435 c904168 df3088f b9ffa51 c904168 df3088f c904168 df3088f b9ffa51 c904168 1934207 df3088f c904168 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import streamlit as st
def number_field(label, **widget_kwargs):
    """Render a numeric input as a two-column row: label left, widget right.

    The label text is written into the narrower left column (width ratio 2)
    and an ``st.number_input`` with an empty label is placed in the wider
    right column (width ratio 4), so the only visible label is the left text.

    Args:
        label: Text written into the left column.
        **widget_kwargs: Passed through unchanged to ``number_input``
            (e.g. ``value=...``).

    Returns:
        The value returned by the right column's ``number_input`` call.
    """
    label_col, input_col = st.columns([2, 4])
    label_col.write(label)
    entered = input_col.number_input('', **widget_kwargs)
    return entered
def print_kernel_execution(c1, c2, comp_flop, mem_bytes, *, peak_flops=None, mem_bandwidth=None):
    """Write a kernel's cost summary into two columns and return its estimated time.

    Labels go to ``c1`` and the matching values (as strings) to ``c2``, in
    this order: GFLOP, MB, arithmetic intensity (FLOP per byte) and
    estimated execution time.

    The time estimate is ``comp_flop / peak_flops + mem_bytes / mem_bandwidth``
    scaled by 1000. With rates given per second, this is milliseconds,
    matching the "Time (ms):" label. Compute and memory time are summed,
    not overlapped.

    Args:
        c1: Object with a ``write`` method that receives the label strings.
        c2: Object with a ``write`` method that receives the value strings.
        comp_flop: Floating-point operations performed by the kernel.
        mem_bytes: Bytes the kernel moves to/from memory.
        peak_flops: Compute rate for the FLOP term. Defaults to the
            module-level ``TFLOPS`` constant.
        mem_bandwidth: Memory rate for the byte term. Defaults to the
            module-level ``GB_S`` constant.

    Returns:
        The estimated execution time (unrounded), in the units described above.
    """
    # Resolve defaults at call time so the module-level constants may be
    # defined after this function.
    if peak_flops is None:
        peak_flops = TFLOPS
    if mem_bandwidth is None:
        mem_bandwidth = GB_S

    # A kernel that moves no bytes has unbounded arithmetic intensity;
    # report that instead of raising ZeroDivisionError.
    arith_int = comp_flop / mem_bytes if mem_bytes else float("inf")
    exec_time = (comp_flop / peak_flops + mem_bytes / mem_bandwidth) * 1000

    # Display-only unit conversions (FLOP -> GFLOP, bytes -> MB), 2 decimals.
    # Computed from the parameters, not from any global kernel counts.
    gflop = round(comp_flop / 1e9, 2)
    megabytes = round(mem_bytes / 1e6, 2)

    c1.write("GFLOP:")
    c2.write(str(gflop))
    c1.write("MB: ")
    c2.write(str(megabytes))
    c1.write("Arithm. intensity:")
    c2.write(str(arith_int))
    c1.write("Time (ms):")
    c2.write(str(exec_time))

    return exec_time
# Peak hardware rates used by print_kernel_execution (FLOP/s and bytes/s).
# NOTE(review): 312e12 / 1935e9 look like NVIDIA A100 80GB figures
# (dense FP16/BF16 tensor-core peak and HBM bandwidth) — confirm.
TFLOPS = 312e12
GB_S = 1935e9

st.header("Transformer parameters")
# min_value=1 keeps the formulas below well-defined: h is a divisor
# (2/h, d/h), so h = 0 would raise ZeroDivisionError.
bs = number_field('Batch size', value=10, min_value=1)
h = number_field('Num heads', value=16, min_value=1)
d = number_field('Dimension', value=768, min_value=1)
n = number_field('Seq length', value=1024, min_value=1)

# NOTE(review): the estimates appear to model a single decoding step
# (query length 1, hence the `*1*` factors). The leading 2 in FLOP counts is
# presumably multiply+add; in byte counts, presumably 2-byte (fp16/bf16)
# elements — confirm.

st.header('Query, Key, Value projection')

st.subheader("Multi-Head Attention")
# Byte terms: input activations (bs*d), Q/K/V weights (3*d*d), Q/K/V outputs (bs*3*d).
mha_flop = 2*bs*1*d*3*d
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
c1, c2 = st.columns([2, 3])
qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)

st.subheader("Multi-Query Attention")
# The (1+2/h) factor is a full-width Q projection plus K and V at 1/h width
# each. Weights and outputs use the same factor as the FLOP count (previously
# only 2/h, which dropped the Q projection from the byte count).
mqa_flop = 2*bs*1*d*(1+2/h)*d
mqa_bytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
c1, c2 = st.columns([2, 3])
# Stored separately; this used to overwrite qkv_mha_time.
qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)

st.header('Attention')

st.subheader("Multi-Head Attention")
# NOTE(review): byte terms read as queries (bs*h*(d/h)), keys for all n
# positions of every head (bs*h*n*(d/h)) and scores (bs*h*n) — confirm.
mha_flop = 2*bs*h*(d/h)*n
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
c1, c2 = st.columns([2, 3])
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)

st.subheader("Multi-Query Attention")
# Same FLOPs as MHA; the key term has no factor h because one key head is
# shared by all query heads.
mqa_flop = 2*bs*h*(d/h)*n
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
c1, c2 = st.columns([2, 3])
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)

st.header('MLP')
# TODO: cost estimates for the MLP layers are not implemented yet.
st.subheader('First Linear')
st.subheader('Second Linear')
|