File size: 2,428 Bytes
44ac4da
 
1adbed4
 
 
 
 
 
1934207
 
 
f98bc09
d9d6497
 
1adbed4
 
 
 
f98bc09
 
f86ffab
df3088f
 
 
622e054
1934207
622e054
 
 
 
f86ffab
df3088f
2e71f02
622e054
4719f35
622e054
4719f35
 
622e054
1934207
 
622e054
f86ffab
df3088f
 
4885a19
 
f86ffab
1934207
 
 
df3088f
d9d6497
1934207
d9d6497
1934207
d9d6497
 
4885a19
 
 
f86ffab
23b5435
df3088f
b9ffa51
df3088f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9ffa51
df3088f
 
1934207
df3088f
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import streamlit as st

def number_field(label, **args):
    """Render a captioned numeric input on one row and return its value.

    The row is split into two columns with relative widths 2:4. ``label`` is
    written into the left column and a ``st.number_input`` widget is placed
    in the right one.

    Args:
        label: Caption written to the left of the input. It is also passed to
            the widget as its label (hidden by default), so the widget never
            receives an empty label.
        **args: Keyword arguments forwarded to ``st.number_input`` (for
            example ``value``). A caller-supplied ``label_visibility``
            overrides the default chosen here.

    Returns:
        The value returned by ``st.number_input`` for this widget.
    """
    caption_col, input_col = st.columns([2, 4])
    caption_col.write(label)

    # Streamlit discourages empty widget labels (accessibility) and may
    # reject them in the future; pass the real label and collapse it instead,
    # since the caption is already shown in the left column.
    args.setdefault('label_visibility', 'collapsed')
    return input_col.number_input(label, **args)

# Hardware throughput figures used as divisors in the time estimates below:
# FLOP counts are divided by TFLOPS and byte counts by GB_S.
# NOTE(review): presumably peak compute in FLOP/s and memory bandwidth in
# bytes/s of a single accelerator (the numbers look like an NVIDIA A100) —
# confirm against the intended hardware.
TFLOPS = 312e12
GB_S = 1935e9

st.header("Transformer parameters")

# Workload sizes entered by the user. Each value is used as a factor or a
# divisor in the cost formulas further down (e.g. d/h and 2/h, and
# FLOP/bytes ratios), so inputs are restricted to >= 1 to rule out
# ZeroDivisionError and meaningless negative sizes.
bs = number_field('Batch size', value=10, min_value=1)
h = number_field('Num heads', value=16, min_value=1)
d = number_field('Dimension', value=768, min_value=1)
n = number_field('Seq length', value=1024, min_value=1)

st.header('Query, Key, Value projection')

# Multi-Head Attention: cost of projecting one position per batch element
# through a d x 3d weight (Q, K and V together).
# NOTE(review): the leading factor 2 in the byte terms presumably stands for
# 2 bytes per element (16-bit storage) — confirm.
input_bytes = 2*bs*1*d
weight_bytes = 2*3*d*d
output_bytes = 2*bs*1*3*d

mha_flop = 2*bs*1*d*3*d
mha_bytes = input_bytes + weight_bytes + output_bytes
mha_int = mha_flop/mha_bytes

# Estimated time in milliseconds: compute time plus memory-transfer time.
compute_s = mha_flop/TFLOPS
memory_s = mha_bytes/GB_S
mha_time = (compute_s + memory_s)*1000

# Convert to GFLOP / MB for display only; intensity and time were computed
# above from the unrounded values.
mha_flop = round(mha_flop/1e9, 2)
mha_bytes = round(mha_bytes/1e6, 2)

st.subheader("Multi-Head Attention")
label_col, value_col = st.columns([2, 3])
for caption, shown in (
    ("GFLOP:", mha_flop),
    ("MB: ", mha_bytes),
    ("Arithm. intensity:", mha_int),
    ("Time (ms):", mha_time),
):
    label_col.write(caption)
    value_col.write(str(shown))


# Multi-Query Attention: Q keeps the full width d while K and V are each
# projected to a single head's width d/h, so the combined projection is
# (1 + 2/h) * d wide.
mqa_flop = 2*bs*1*d*(1+2/h)*d

# The same (1 + 2/h) factor has to scale the weight and output traffic:
# the Q weights are read and the Q output is written as well, not only the
# K/V share. Using 2/h alone would understate the bytes moved and inflate
# the arithmetic intensity relative to the FLOP count above.
# NOTE(review): the leading factor 2 presumably stands for 2 bytes per
# element (16-bit storage) — confirm.
mqa_bytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
mqa_intensity = mqa_flop/mqa_bytes

# Estimated time in milliseconds: compute time plus memory-transfer time.
mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000

# Convert to GFLOP / MB for display only.
mqa_flop = round(mqa_flop/1e9, 2)
mqa_bytes = round(mqa_bytes/1e6, 2)

st.subheader("Multi-Query Attention")
c1, c2 = st.columns([2, 3])
for caption, shown in (
    ("GFLOP:", mqa_flop),
    ("MB:", mqa_bytes),
    ("Arithm. intensity:", mqa_intensity),
    ("Time (ms):", mqa_time),
):
    c1.write(caption)
    c2.write(str(shown))

st.header('Attention')

# Multi-Head Attention: one query position per batch element scored against
# n positions, with h heads of width d/h each.
head_dim = d/h

mha_flop = 2*bs*h*head_dim*n

# Byte terms, in the order they are summed.
# NOTE(review): they look like query, per-head keys and output scores at
# 2 bytes per element — confirm.
query_bytes = 2*bs*h*head_dim
key_bytes = 2*bs*h*n*head_dim
score_bytes = 2*bs*h*n
mha_bytes = query_bytes + key_bytes + score_bytes

mha_int = mha_flop/mha_bytes

# Estimated time in milliseconds: compute time plus memory-transfer time.
compute_s = mha_flop/TFLOPS
memory_s = mha_bytes/GB_S
mha_time = (compute_s + memory_s)*1000

# Convert to GFLOP / MB for display only.
mha_flop = round(mha_flop/1e9, 2)
mha_bytes = round(mha_bytes/1e6, 2)

st.subheader("Multi-Head Attention")
label_col, value_col = st.columns([2, 3])
for caption, shown in (
    ("GFLOP:", mha_flop),
    ("MB: ", mha_bytes),
    ("Arithm. intensity:", mha_int),
    ("Time (ms):", mha_time),
):
    label_col.write(caption)
    value_col.write(str(shown))

# Multi-Query Attention: same FLOP count as the multi-head case, but the
# middle byte term carries no factor h.
# NOTE(review): presumably because a single set of keys is shared by all
# heads — confirm.
head_dim = d/h

mqa_flop = 2*bs*h*head_dim*n

# Byte terms, in the order they are summed.
query_bytes = 2*bs*h*head_dim
shared_key_bytes = 2*bs*n*head_dim
score_bytes = 2*bs*h*n
mqa_bytes = query_bytes + shared_key_bytes + score_bytes

mqa_intensity = mqa_flop/mqa_bytes

# Estimated time in milliseconds: compute time plus memory-transfer time.
compute_s = mqa_flop/TFLOPS
memory_s = mqa_bytes/GB_S
mqa_time = (compute_s + memory_s)*1000

# Convert to GFLOP / MB for display only.
mqa_flop = round(mqa_flop/1e9, 2)
mqa_bytes = round(mqa_bytes/1e6, 2)

st.subheader("Multi-Query Attention")
label_col, value_col = st.columns([2, 3])
for caption, shown in (
    ("GFLOP:", mqa_flop),
    ("MB:", mqa_bytes),
    ("Arithm. intensity:", mqa_intensity),
    ("Time (ms):", mqa_time),
):
    label_col.write(caption)
    value_col.write(str(shown))