Spaces:
Runtime error
Runtime error
File size: 2,937 Bytes
44ac4da 35a9fc2 1adbed4 28b4830 1adbed4 c904168 28b4830 ea57214 c904168 3edd3ca c904168 3edd3ca c904168 3edd3ca c904168 3edd3ca c904168 1934207 f98bc09 d9d6497 1adbed4 8cbefab a275f69 f98bc09 a275f69 f86ffab a275f69 c904168 2e71f02 c904168 f86ffab a275f69 df3088f d9d6497 c904168 f86ffab a275f69 c904168 a275f69 df3088f b9ffa51 c904168 8cbefab df3088f a275f69 8cbefab df3088f a275f69 8cbefab a275f69 28b4830 8cbefab df3088f a275f69 28b4830 c904168 1934207 a275f69 c904168 54cd0e6 c904168 54cd0e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import streamlit as st
def number_field(label, **kwargs):
    """Render a captioned numeric input as a two-column row and return its value.

    The caption is written into the narrow left column and the
    ``st.number_input`` widget is placed in the wider right column.

    Args:
        label: Text shown to the left of the input. Also used as the default
            widget key.
        **kwargs: Forwarded unchanged to ``st.number_input`` (for example
            ``value`` or ``min_value``). An explicit ``key`` takes precedence
            over the default one.

    Returns:
        Whatever ``st.number_input`` returns for the current widget state.
    """
    label_col, input_col = st.columns([2, 4])
    label_col.write(label)
    # The widget itself is created with an empty label, so without a key its
    # identity depends only on the remaining arguments and two fields with the
    # same defaults would collide (DuplicateWidgetID). Keying by the visible
    # label keeps each field distinct.
    kwargs.setdefault('key', label)
    return input_col.number_input('', **kwargs)
# Hardware limits used by the cost model in print_kernel_execution.
# NOTE(review): the figures look like NVIDIA A100 peak numbers (312 TFLOP/s
# half-precision tensor-core compute, 1935 GB/s memory bandwidth) -- confirm.
TFLOPS = 312e12  # compute throughput; used as FLOP per second despite the name
GB_S = 1935e9  # memory bandwidth; used as bytes per second despite the name


def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
    """Display a kernel's cost summary and return its estimated run time.

    Labels are written to ``c1`` and the matching values to ``c2``: GFLOP,
    MB, arithmetic intensity (FLOP per byte) and the estimated time.

    The time estimate is the sum of the compute time (``comp_flop / TFLOPS``)
    and the memory time (``mem_bytes / GB_S``), converted to milliseconds; the
    two phases are added, not overlapped.

    Args:
        c1: Object with a ``write`` method that receives the row labels.
        c2: Object with a ``write`` method that receives the row values.
        comp_flop: Number of floating-point operations performed by the kernel.
        mem_bytes: Number of bytes the kernel moves to and from memory.

    Returns:
        The estimated execution time in milliseconds (unrounded).
    """
    # Arithmetic intensity is undefined when no bytes are moved; show "n/a"
    # instead of raising ZeroDivisionError and taking the whole page down.
    arith_int = comp_flop / mem_bytes if mem_bytes else "n/a"
    exec_time = (comp_flop/TFLOPS + mem_bytes/GB_S)*1000

    # Scale to friendlier units for display only; the returned time is based
    # on the raw values above.
    gflop = round(comp_flop/1e9, 2)
    mbytes = round(mem_bytes/1e6, 2)

    c1.write("GFLOP:")
    c2.write(str(gflop))
    c1.write("MB: ")
    c2.write(str(mbytes))
    c1.write("Arithm. intensity:")
    c2.write(str(arith_int))
    c1.write("Time (ms):")
    c2.write(str(exec_time))

    return exec_time
# --- Model and workload parameters entered by the user -----------------------
st.header("Transformer parameters")
# NOTE(review): col1/col2 are not referenced in the visible code; the call is
# kept so the page layout (an empty column row under the header) is unchanged.
col1, col2 = st.columns([2, 4])

# Every value is used as a multiplier, and `h` and the byte counts are also
# used as divisors further down, so zero or negative input would either crash
# the page (ZeroDivisionError) or produce meaningless numbers. min_value=1
# makes the widgets reject such input.
bs = number_field('Batch size', value=10, min_value=1)
h = number_field('Num heads', value=16, min_value=1)
d = number_field('Dimension', value=768, min_value=1)
# NOTE(review): n_start is collected but not used by any calculation visible here.
n_start = number_field('Start seq', value=1, min_value=1)
n = number_field('End seq', value=1024, min_value=1)
# NOTE(review): l is collected but not used by any calculation visible here.
l = number_field('Num layers', value=24, min_value=1)
# --- Attention layer: QKV projection -----------------------------------------
# Cost of projecting the input activations to queries, keys and values.
# NOTE(review): the leading factor 2 in the byte counts presumably means
# 2 bytes per element (fp16/bf16) and the literal 1 a single token per
# sequence (one decoding step) -- confirm.
st.header('Attention layer')

st.subheader('QKV projection')

# Multi-head: the projection output width is 3*d (Q, K and V are each d wide).
st.caption("Multi-Head Attention")
mha_flop = 2*bs*1*d*3*d
# input activations + projection weights + projected output
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
c1, c2 = st.columns([2, 3])
qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)

# Multi-query: Q stays d wide, K and V are a single head (d/h) each, so the
# projection output width is (1 + 2/h)*d.
st.caption("Multi-Query Attention")
mqa_flop = 2*bs*1*d*(1+2/h)*d
# The weight and output terms must use the same (1 + 2/h) width as the FLOP
# count; using only 2/h would leave out the Q projection's memory traffic.
mqa_bytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
c1, c2 = st.columns([2, 3])
# Stored under its own name so the multi-head timing above is not overwritten.
qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
# --- Attention layer: query-key product --------------------------------------
st.subheader('QK gemm')
st.write("Note that calculation depends on sequence length (n)")

# Quantities shared by both attention variants.
head_dim = d/h
query_bytes = 2*bs*h*head_dim  # query tensor read by the kernel
score_bytes = 2*bs*h*n  # attention scores written by the kernel

# Multi-head: every head reads its own n keys.
st.caption("Multi-Head Attention")
mha_flop = 2*bs*h*head_dim*n
per_head_key_bytes = 2*bs*h*n*head_dim
mha_bytes = query_bytes + per_head_key_bytes + score_bytes
c1, c2 = st.columns([2, 3])
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)

# Multi-query: same arithmetic, but one key tensor is shared by all heads, so
# the key term loses its factor h.
st.caption("Multi-Query Attention")
mqa_flop = 2*bs*h*head_dim*n
shared_key_bytes = 2*bs*n*head_dim
mqa_bytes = query_bytes + shared_key_bytes + score_bytes
c1, c2 = st.columns([2, 3])
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
# --- Attention layer: attention-value product --------------------------------
# Rendered as a subheader like the other attention steps ('QKV projection',
# 'QK gemm'); a top-level header would break it out of the attention section.
st.subheader('Attention-value gemm')
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")

# Multi-head: every head reads its own n value vectors.
st.caption("Multi-Head Attention")
mha_flop = 2*bs*h*n*(d/h)
# attention scores + per-head values + output
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
c1, c2 = st.columns([2, 3])
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)

# Multi-query: the value tensor is shared across heads (no factor h), but the
# attention scores are still per head, so the first term keeps its factor h.
st.caption("Multi-Query Attention")
mqa_flop = 2*bs*h*n*(d/h)
# attention scores + shared values + output
mqa_bytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
c1, c2 = st.columns([2, 3])
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
# --- Attention layer: output projection --------------------------------------
# This section previously had a title but no content. The estimate follows the
# same pattern as the other linear layers on the page: a d -> d projection.
# NOTE(review): assumes the same 2-bytes-per-element, one-token-per-sequence
# conventions as the neighbouring sections -- confirm.
st.subheader('Output projection')
out_flop = 2*bs*1*d*d
# input activations + projection weights + projected output
out_bytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
c1, c2 = st.columns([2, 3])
out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
# --- Feed-forward block ------------------------------------------------------
# Both linear layers are costed with the same expressions: a matrix product
# between d-wide and 4*d-wide activations.
st.header('MLP')

st.subheader('First Linear')
mlp1_flop = 2 * bs * 1 * d * 4 * d
mlp1_narrow_act_bytes = 2 * bs * 1 * d
mlp1_weight_bytes = 2 * d * 4 * d
mlp1_wide_act_bytes = 2 * bs * 1 * 4 * d
mlp1_bytes = mlp1_narrow_act_bytes + mlp1_weight_bytes + mlp1_wide_act_bytes
c1, c2 = st.columns([2, 3])
mlp1_time = print_kernel_execution(c1, c2, mlp1_flop, mlp1_bytes)

st.subheader('Second Linear')
mlp2_flop = 2 * bs * 1 * d * 4 * d
mlp2_narrow_act_bytes = 2 * bs * 1 * d
mlp2_weight_bytes = 2 * d * 4 * d
mlp2_wide_act_bytes = 2 * bs * 1 * 4 * d
mlp2_bytes = mlp2_narrow_act_bytes + mlp2_weight_bytes + mlp2_wide_act_bytes
c1, c2 = st.columns([2, 3])
mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
|