harmdevries commited on
Commit
1934207
·
1 Parent(s): 622e054

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -6,6 +6,9 @@ def number_field(label, **args):
6
 
7
  return c2.number_input('', **args)
8
 
 
 
 
9
  st.header("Transformer parameters")
10
  col1, col2 = st.columns([2, 4])
11
 
@@ -19,6 +22,7 @@ st.header('Query, Key, Value projection')
19
  mha_flop = 2*bs*n*d*3*d
20
  mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d
21
  mha_int = mha_flop/mha_bytes
 
22
 
23
  mha_flop = round(mha_flop/1e9, 2)
24
  mha_bytes = round(mha_bytes/1e6, 2)
@@ -32,18 +36,24 @@ c1.write("MB: ")
32
  c2.write(str(mha_bytes))
33
  c1.write("Arithm. intensity:")
34
  c2.write(str(mha_int))
 
 
35
 
36
 
37
  mqa_flop = 2*bs*n*d*(1+2/h)*d
38
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
39
 
 
 
 
40
  st.subheader("Multi-query Attention")
41
  c1, c2 = st.columns([2, 3])
42
- c1.write("FLOP:")
43
  c2.write(str(mqa_flop))
44
- c1.write("Bytes: ")
45
  c2.write(str(mqa_bytes))
46
  c1.write("Arithm. intensity:")
47
  c2.write(str(mqa_flop/mqa_bytes))
48
 
49
  st.header('Attention')
 
 
6
 
7
  return c2.number_input('', **args)
8
 
9
+ TFLOPS = 312e12
10
+ GB_S = 1935e9
11
+
12
  st.header("Transformer parameters")
13
  col1, col2 = st.columns([2, 4])
14
 
 
22
  mha_flop = 2*bs*n*d*3*d
23
  mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d
24
  mha_int = mha_flop/mha_bytes
25
+ mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
26
 
27
  mha_flop = round(mha_flop/1e9, 2)
28
  mha_bytes = round(mha_bytes/1e6, 2)
 
36
  c2.write(str(mha_bytes))
37
  c1.write("Arithm. intensity:")
38
  c2.write(str(mha_int))
39
+ c1.write("Time (ms):")
40
+ c2.write(str(mha_time))
41
 
42
 
43
  mqa_flop = 2*bs*n*d*(1+2/h)*d
44
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
45
 
46
+ mqa_flop = round(mqa_flop/1e9, 2)
47
+ mqa_bytes = round(mqa_bytes/1e6, 2)
48
+
49
  st.subheader("Multi-query Attention")
50
  c1, c2 = st.columns([2, 3])
51
+ c1.write("GFLOP:")
52
  c2.write(str(mqa_flop))
53
+ c1.write("MB:")
54
  c2.write(str(mqa_bytes))
55
  c1.write("Arithm. intensity:")
56
  c2.write(str(mqa_flop/mqa_bytes))
57
 
58
  st.header('Attention')
59
+