File size: 1,325 Bytes
44ac4da
 
1adbed4
 
 
 
 
 
1934207
 
 
f98bc09
d9d6497
 
1adbed4
 
 
 
f98bc09
 
f86ffab
f98bc09
 
622e054
1934207
622e054
 
 
 
f86ffab
23b5435
2e71f02
622e054
4719f35
622e054
4719f35
 
622e054
1934207
 
622e054
f86ffab
 
 
 
1934207
 
 
23b5435
d9d6497
1934207
d9d6497
1934207
d9d6497
 
 
f86ffab
23b5435
1934207
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import streamlit as st

def number_field(label, **args):
  """Render a labelled numeric input on one row and return its value.

  The label text is written into a narrow left column and the
  ``st.number_input`` widget is placed in a wider right column
  (2:4 width ratio).

  Args:
    label: Text shown to the left of the input. It is also passed to the
      widget as its (hidden) label.
    **args: Extra keyword arguments forwarded to ``st.number_input``
      (for example ``value`` or ``min_value``).

  Returns:
    The value returned by ``st.number_input`` for this widget.
  """
  c1, c2 = st.columns([2, 4])
  c1.write(label)

  # Give the widget its real label instead of an empty string: an empty
  # label is discouraged by Streamlit, and with '' two fields that share
  # the same arguments (e.g. the same default value) cannot be told apart.
  # The label is collapsed so the visible layout still shows it only once,
  # in the left column. setdefault lets a caller override the visibility.
  args.setdefault('label_visibility', 'collapsed')
  return c2.number_input(label, **args)

# Hardware peak figures used for the time estimates below.
# NOTE(review): these look like accelerator peak numbers (compute throughput
# and memory bandwidth) - confirm which device they are meant to model.
TFLOPS = 312e12  # floating-point operations per second
GB_S = 1935e9    # bytes per second

st.header("Transformer parameters")
# Not used in the visible part of the script; kept in case later code
# refers to these columns.
col1, col2 = st.columns([2, 4])

# min_value=1 keeps every size strictly positive, so the divisions below
# (2/h and flop/bytes) can never hit a zero denominator.
bs = number_field('Batch size', value=10, min_value=1)
h = number_field('Num heads', value=16, min_value=1)
d = number_field('Dimension', value=768, min_value=1)
n = number_field('Seq length', value=1024, min_value=1)

st.header('Query, Key, Value projection')

# --- Multi-head attention: project the input to Q, K and V (3*d outputs) ---
# The leading factor 2 in the FLOP count is multiply + add per MAC.
# NOTE(review): the factor 2 in the byte counts presumably means 2 bytes per
# element (16-bit values) - confirm.
mha_flop = 2*bs*n*d*3*d
mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d  # input + weights + output
mha_int = mha_flop/mha_bytes                 # FLOP per byte, unrounded
mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000  # seconds -> ms

# Rescale to display units only after the derived metrics are computed.
mha_flop = round(mha_flop/1e9, 2)
mha_bytes = round(mha_bytes/1e6, 2)


st.subheader("Multi-head Attention")
c1, c2 = st.columns([2, 3])
c1.write("GFLOP:")
c2.write(str(mha_flop))
c1.write("MB: ")
c2.write(str(mha_bytes))
c1.write("Arithm. intensity:")
c2.write(str(mha_int))
c1.write("Time (ms):")
c2.write(str(mha_time))


# --- Multi-query attention: full-size Q, K and V shrunk by 1/h each ---
mqa_flop = 2*bs*n*d*(1+2/h)*d
# NOTE(review): the FLOP count uses the factor (1 + 2/h), i.e. it includes
# the Q projection, but the weight and output byte terms below use only
# (2/h). That looks like the Q weights/outputs are missing from the byte
# count - confirm intent before changing the formula.
mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
# Derive intensity and time from the raw values, exactly as for MHA, so the
# two sections are directly comparable (dividing the rounded GFLOP by the
# rounded MB would be off by a factor of 1000 and lose precision).
mqa_int = mqa_flop/mqa_bytes
mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000  # seconds -> ms

mqa_flop = round(mqa_flop/1e9, 2)
mqa_bytes = round(mqa_bytes/1e6, 2)

st.subheader("Multi-query Attention")
c1, c2 = st.columns([2, 3])
c1.write("GFLOP:")
c2.write(str(mqa_flop))
c1.write("MB:")
c2.write(str(mqa_bytes))
c1.write("Arithm. intensity:")
c2.write(str(mqa_int))
c1.write("Time (ms):")
c2.write(str(mqa_time))

st.header('Attention')