harmdevries committed on
Commit
df3088f
·
1 Parent(s): 4885a19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -6
app.py CHANGED
@@ -19,8 +19,9 @@ n = number_field('Seq length', value=1024)
19
 
20
  st.header('Query, Key, Value projection')
21
 
22
- mha_flop = 2*bs*n*d*3*d
23
- mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d
 
24
  mha_int = mha_flop/mha_bytes
25
  mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
26
 
@@ -28,7 +29,7 @@ mha_flop = round(mha_flop/1e9, 2)
28
  mha_bytes = round(mha_bytes/1e6, 2)
29
 
30
 
31
- st.subheader("Multi-head Attention")
32
  c1, c2 = st.columns([2, 3])
33
  c1.write("GFLOP:")
34
  c2.write(str(mha_flop))
@@ -40,15 +41,15 @@ c1.write("Time (ms):")
40
  c2.write(str(mha_time))
41
 
42
 
43
- mqa_flop = 2*bs*n*d*(1+2/h)*d
44
- mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
45
  mqa_intensity = mqa_flop/mqa_bytes
46
  mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000
47
 
48
  mqa_flop = round(mqa_flop/1e9, 2)
49
  mqa_bytes = round(mqa_bytes/1e6, 2)
50
 
51
- st.subheader("Multi-query Attention")
52
  c1, c2 = st.columns([2, 3])
53
  c1.write("GFLOP:")
54
  c2.write(str(mqa_flop))
@@ -60,4 +61,41 @@ c1.write("Time (ms):")
60
  c2.write(str(mqa_time))
61
 
62
  st.header('Attention')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  st.header('Query, Key, Value projection')
21
 
22
+
23
+ mha_flop = 2*bs*1*d*3*d
24
+ mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
25
  mha_int = mha_flop/mha_bytes
26
  mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
27
 
 
29
  mha_bytes = round(mha_bytes/1e6, 2)
30
 
31
 
32
+ st.subheader("Multi-Head Attention")
33
  c1, c2 = st.columns([2, 3])
34
  c1.write("GFLOP:")
35
  c2.write(str(mha_flop))
 
41
  c2.write(str(mha_time))
42
 
43
 
44
+ mqa_flop = 2*bs*1*d*(1+2/h)*d
45
+ mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
46
  mqa_intensity = mqa_flop/mqa_bytes
47
  mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000
48
 
49
  mqa_flop = round(mqa_flop/1e9, 2)
50
  mqa_bytes = round(mqa_bytes/1e6, 2)
51
 
52
+ st.subheader("Multi-Query Attention")
53
  c1, c2 = st.columns([2, 3])
54
  c1.write("GFLOP:")
55
  c2.write(str(mqa_flop))
 
61
  c2.write(str(mqa_time))
62
 
63
  st.header('Attention')
64
+ mha_flop = 2*bs*h*(d/h)*n
65
+ mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n*n
66
+ mha_int = mha_flop/mha_bytes
67
+ mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
68
+
69
+ mha_flop = round(mha_flop/1e9, 2)
70
+ mha_bytes = round(mha_bytes/1e6, 2)
71
+
72
+
73
+ st.subheader("Multi-Head Attention")
74
+ c1, c2 = st.columns([2, 3])
75
+ c1.write("GFLOP:")
76
+ c2.write(str(mha_flop))
77
+ c1.write("MB: ")
78
+ c2.write(str(mha_bytes))
79
+ c1.write("Arithm. intensity:")
80
+ c2.write(str(mha_int))
81
+ c1.write("Time (ms):")
82
+ c2.write(str(mha_time))
83
+
84
+ mqa_flop = 2*bs*h*(d/h)*n
85
+ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n*n
86
+ mqa_intensity = mqa_flop/mqa_bytes
87
+ mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000
88
 
89
+ mqa_flop = round(mqa_flop/1e9, 2)
90
+ mqa_bytes = round(mqa_bytes/1e6, 2)
91
+
92
+ st.subheader("Multi-Query Attention")
93
+ c1, c2 = st.columns([2, 3])
94
+ c1.write("GFLOP:")
95
+ c2.write(str(mqa_flop))
96
+ c1.write("MB:")
97
+ c2.write(str(mqa_bytes))
98
+ c1.write("Arithm. intensity:")
99
+ c2.write(str(mqa_intensity))
100
+ c1.write("Time (ms):")
101
+ c2.write(str(mqa_time))