Spaces:
Runtime error
Runtime error
File size: 2,428 Bytes
44ac4da 1adbed4 1934207 f98bc09 d9d6497 1adbed4 f98bc09 f86ffab df3088f 622e054 1934207 622e054 f86ffab df3088f 2e71f02 622e054 4719f35 622e054 4719f35 622e054 1934207 622e054 f86ffab df3088f 4885a19 f86ffab 1934207 df3088f d9d6497 1934207 d9d6497 1934207 d9d6497 4885a19 f86ffab 23b5435 df3088f b9ffa51 df3088f b9ffa51 df3088f 1934207 df3088f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import streamlit as st
def number_field(label, **args):
    """Render a numeric input as a two-column row: label text left, widget right.

    Args:
        label: Text written in the left column. Also used as the default
            widget ``key`` so that every field has its own identity.
        **args: Extra keyword arguments forwarded to ``st.number_input``
            (e.g. ``value``, ``min_value``). An explicit ``key`` here wins
            over the label-based default.

    Returns:
        Whatever ``st.number_input`` returns for the widget.
    """
    c1, c2 = st.columns([2, 4])
    c1.write(label)
    # The widget's own label is left empty because the label is already shown
    # in the left column. Without an explicit key, Streamlit builds the
    # widget's identity from its label and parameters, so two fields with the
    # same arguments would collide (DuplicateWidgetID). Keying on the visible
    # label keeps each field distinct.
    args.setdefault("key", label)
    return c2.number_input('', **args)
# Hardware throughput used by the timing model below (time = FLOP / TFLOPS
# + bytes / GB_S, converted to ms).
# NOTE(review): these look like A100 peak FP16/BF16 tensor-core throughput
# and HBM bandwidth — confirm which device the numbers are meant to model.
TFLOPS = 312e12  # peak compute, FLOP per second
GB_S = 1935e9    # memory bandwidth, bytes per second

st.header("Transformer parameters")
# Every field is kept >= 1: the cost formulas divide by h, by d/h, and by
# byte counts that become zero when bs or d is zero, so a 0 input would
# abort the script with ZeroDivisionError.
bs = number_field('Batch size', value=10, min_value=1)
h = number_field('Num heads', value=16, min_value=1)
d = number_field('Dimension', value=768, min_value=1)
n = number_field('Seq length', value=1024, min_value=1)
st.header('Query, Key, Value projection')
# Cost of projecting one new token per batch element (sequence length 1)
# into queries, keys and values. Every tensor element is counted as 2 bytes
# (presumably fp16/bf16 — confirm). Time is modelled as compute time plus
# memory time with no overlap, reported in milliseconds.

# Multi-head attention: Q, K and V each project d -> d, i.e. a d x 3d
# weight matrix.
mha_flop = 2*bs*1*d*3*d
# input activations + weights + Q/K/V outputs
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
mha_int = mha_flop/mha_bytes
mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
mha_flop = round(mha_flop/1e9, 2)
mha_bytes = round(mha_bytes/1e6, 2)
st.subheader("Multi-Head Attention")
c1, c2 = st.columns([2, 3])
for label, value in (
    ("GFLOP:", mha_flop),
    ("MB: ", mha_bytes),
    ("Arithm. intensity:", mha_int),
    ("Time (ms):", mha_time),
):
    c1.write(label)
    c2.write(str(value))

# Multi-query attention: Q still projects d -> d, while the single shared
# K and V heads each project d -> d/h, i.e. a d x (1 + 2/h)d weight matrix.
mqa_flop = 2*bs*1*d*(1+2/h)*d
# Weights and outputs carry the same (1 + 2/h) factor as the FLOP count, so
# the Q projection's weights and outputs are included in the memory
# traffic, not only the K/V part.
mqa_bytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
mqa_intensity = mqa_flop/mqa_bytes
mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000
mqa_flop = round(mqa_flop/1e9, 2)
mqa_bytes = round(mqa_bytes/1e6, 2)
st.subheader("Multi-Query Attention")
c1, c2 = st.columns([2, 3])
for label, value in (
    ("GFLOP:", mqa_flop),
    ("MB:", mqa_bytes),
    ("Arithm. intensity:", mqa_intensity),
    ("Time (ms):", mqa_time),
):
    c1.write(label)
    c2.write(str(value))
st.header('Attention')
# Cost of the attention-score product for one decoding step: each head's
# single query (length d/h) is dotted against n keys, giving n scores per
# head. The byte counts cover the query, the keys and the score output at
# 2 bytes per element; time is compute time plus memory time, in ms.


def _show_attention_stats(title, mb_label, stats):
    """Write a subheader and a two-column caption/value table.

    ``stats`` holds (GFLOP, MB, arithmetic intensity, time in ms); each is
    written in that order next to its caption, with ``mb_label`` used as
    the caption for the MB row.
    """
    st.subheader(title)
    left, right = st.columns([2, 3])
    captions = ("GFLOP:", mb_label, "Arithm. intensity:", "Time (ms):")
    for caption, stat in zip(captions, stats):
        left.write(caption)
        right.write(str(stat))


head_dim = d/h

# Multi-head attention: every head reads its own n x head_dim keys.
mha_flop = 2*bs*h*head_dim*n
mha_bytes = 2*bs*h*head_dim + 2*bs*h*n*head_dim + 2*bs*h*n
mha_int = mha_flop/mha_bytes
mha_time = 1000*(mha_flop/TFLOPS + mha_bytes/GB_S)
mha_flop, mha_bytes = round(mha_flop/1e9, 2), round(mha_bytes/1e6, 2)
_show_attention_stats("Multi-Head Attention", "MB: ",
                      (mha_flop, mha_bytes, mha_int, mha_time))

# Multi-query attention: the FLOPs are the same, but all heads share one
# n x head_dim key block, so keys are read once per batch element.
mqa_flop = 2*bs*h*head_dim*n
mqa_bytes = 2*bs*h*head_dim + 2*bs*n*head_dim + 2*bs*h*n
mqa_intensity = mqa_flop/mqa_bytes
mqa_time = 1000*(mqa_flop/TFLOPS + mqa_bytes/GB_S)
mqa_flop, mqa_bytes = round(mqa_flop/1e9, 2), round(mqa_bytes/1e6, 2)
_show_attention_stats("Multi-Query Attention", "MB:",
                      (mqa_flop, mqa_bytes, mqa_intensity, mqa_time))
|