harmdevries commited on
Commit
87bf3c7
·
1 Parent(s): bc2a18b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -31
app.py CHANGED
@@ -1,16 +1,39 @@
1
  import streamlit as st
2
 
3
- # A100 specs
4
- TFLOPS = 312e12
5
- GB_S = 1935e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # in ms
8
  THREAD_OVERHEAD = 0.005
9
 
10
  # in ms
11
- def calc_exec_time(comp_flop, mem_bytes):
12
  exec_time = comp_flop/TFLOPS + mem_bytes/GB_S
13
- return max(exec_time*1000, THREAD_OVERHEAD)
 
 
14
 
15
  def qkv_mha_exec(bs, h, n, d):
16
  flop = 2*bs*1*d*3*d
@@ -72,10 +95,11 @@ def mlp_exec(bs, h, n, d):
72
  exec_time = calc_exec_time(flop, nbytes)
73
  return flop, nbytes, exec_time
74
 
75
- def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
76
- exec_time = calc_exec_time(comp_flop, mem_bytes)
77
- comp_flop = round(comp_flop/1e9, 2)
78
- mem_bytes = round(mem_bytes/1e6, 2)
 
79
 
80
  c1.write("GFLOP:")
81
  c2.write(str(comp_flop))
@@ -83,22 +107,11 @@ def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
83
  c2.write(str(mem_bytes))
84
  c1.write("Time (ms):")
85
  c2.write(str(exec_time))
 
 
86
 
87
  return exec_time
88
 
89
- st.sidebar.header("Transformer parameters")
90
- col1, col2 = st.sidebar.columns([2, 4])
91
-
92
- bs = st.sidebar.number_input('Batch size', value=10)
93
- h = st.sidebar.number_input('Num heads',value=16)
94
- d = st.sidebar.number_input('Dimension', value=768)
95
- l = st.sidebar.number_input('Num layers', value=24)
96
-
97
- n_start = st.sidebar.number_input('Start seq', value=1)
98
- n = st.sidebar.number_input('End seq', value=1024)
99
-
100
- st.sidebar.header("GPU parameters")
101
-
102
 
103
  st.header("Execution time (ms)")
104
 
@@ -151,16 +164,13 @@ if breakdown:
151
 
152
  st.subheader('QKV projection')
153
  st.caption("Multi-Head Attention")
154
- mha_flop = 2*bs*1*d*3*d
155
- mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
156
- c1, c2 = st.columns([2, 3])
157
- qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
158
 
159
  st.caption("Multi-Query Attention")
160
- mqa_flop = 2*bs*1*d*(1+2/h)*d
161
- mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
162
- c1, c2 = st.columns([2, 3])
163
- qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
164
 
165
  st.subheader('QK gemm')
166
  st.write("Showing calculation for the maximum sequence length (n)")
@@ -187,7 +197,7 @@ if breakdown:
187
 
188
  st.caption("Multi-Query Attention")
189
  mqa_flop = 2*bs*h*n*(d/h)
190
- mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
191
  c1, c2 = st.columns([2, 3])
192
  att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
193
 
 
1
  import streamlit as st
2
 
3
+ st.sidebar.header("Transformer parameters")
4
+ col1, col2 = st.sidebar.columns([2, 4])
5
+
6
+ bs = st.sidebar.number_input('Batch size', value=10)
7
+ h = st.sidebar.number_input('Num heads',value=16)
8
+ d = st.sidebar.number_input('Dimension', value=768)
9
+ l = st.sidebar.number_input('Num layers', value=24)
10
+
11
+ n_start = st.sidebar.number_input('Start seq', value=1)
12
+ n = st.sidebar.number_input('End seq', value=1024)
13
+
14
+ st.sidebar.header("GPU parameters")
15
+
16
+ GPU = st.selectbox('GPU', ('A100', 'V100'))
17
+
18
+ if GPU == 'A100':
19
+ # A100 specs
20
+ TFLOPS = 312e12
21
+ GB_S = 1935e9
22
+ elif GPU == 'V100':
23
+ TFLOPS = 112e12
24
+ GB_S = 900e9
25
+ else:
26
+ raise ValueError('Unknown GPU')
27
 
28
  # in ms
29
  THREAD_OVERHEAD = 0.005
30
 
31
  # in ms
32
+ def calc_exec_time(comp_flop, mem_bytes, include_overhead=True):
33
  exec_time = comp_flop/TFLOPS + mem_bytes/GB_S
34
+ if include_overhead:
35
+ exec_time = max(exec_time*1000, THREAD_OVERHEAD)
36
+ return exec_time
37
 
38
  def qkv_mha_exec(bs, h, n, d):
39
  flop = 2*bs*1*d*3*d
 
95
  exec_time = calc_exec_time(flop, nbytes)
96
  return flop, nbytes, exec_time
97
 
98
+ def print_kernel_execution(flop, mem_bytes):
99
+ c1, c2 = st.columns([2, 3])
100
+ exec_time = calc_exec_time(comp_flop, mem_bytes, include_overhead=False)
101
+ flop = round(flop/1e9, 2)
102
+ nbytes = round(nbytes/1e6, 2)
103
 
104
  c1.write("GFLOP:")
105
  c2.write(str(comp_flop))
 
107
  c2.write(str(mem_bytes))
108
  c1.write("Time (ms):")
109
  c2.write(str(exec_time))
110
+ c1.write("Overhead (ms):")
111
+ c2.write(str(THREAD_OVERHEAD))
112
 
113
  return exec_time
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  st.header("Execution time (ms)")
117
 
 
164
 
165
  st.subheader('QKV projection')
166
  st.caption("Multi-Head Attention")
167
+ flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
168
+ print_kernel_execution(flop, nbytes)
169
+
 
170
 
171
  st.caption("Multi-Query Attention")
172
+ flop, nbytes, exec_time = qkv_mqa_exec(bs, h, n, d)
173
+ print_kernel_execution(flop, nbytes)
 
 
174
 
175
  st.subheader('QK gemm')
176
  st.write("Showing calculation for the maximum sequence length (n)")
 
197
 
198
  st.caption("Multi-Query Attention")
199
  mqa_flop = 2*bs*h*n*(d/h)
200
+ mqa_bytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
201
  c1, c2 = st.columns([2, 3])
202
  att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
203