harmdevries committed on
Commit
67aca21
·
1 Parent(s): 1c3236f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -26
app.py CHANGED
@@ -31,8 +31,9 @@ THREAD_OVERHEAD = 0.005
31
  # in ms
32
  def calc_exec_time(comp_flop, mem_bytes, include_overhead=True):
33
  exec_time = comp_flop/TFLOPS + mem_bytes/GB_S
 
34
  if include_overhead:
35
- exec_time = max(exec_time*1000, THREAD_OVERHEAD)
36
  return exec_time
37
 
38
  def qkv_mha_exec(bs, h, n, d):
@@ -109,9 +110,6 @@ def print_kernel_execution(flop, nbytes):
109
  c2.write(str(exec_time))
110
  c1.write("Overhead (ms):")
111
  c2.write(str(THREAD_OVERHEAD))
112
-
113
- return exec_time
114
-
115
 
116
  st.header("Execution time (ms)")
117
 
@@ -185,44 +183,34 @@ if breakdown:
185
  st.subheader('Attention-value gemm')
186
  st.write("Showing calculation for the maximum sequence length (n)")
187
  st.caption("Multi-Head Attention")
188
- mha_flop = 2*bs*h*n*(d/h)
189
- mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
190
- c1, c2 = st.columns([2, 3])
191
- att2_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
192
 
193
  st.caption("Multi-Query Attention")
194
- mqa_flop = 2*bs*h*n*(d/h)
195
- mqa_bytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
196
- c1, c2 = st.columns([2, 3])
197
- att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
198
 
199
  st.subheader('Output projection')
200
- out_flop = 2*bs*1*d*d
201
- out_bytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
202
- c1, c2 = st.columns([2, 3])
203
- out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
204
 
205
  st.subheader('Element-wise ops')
206
  st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
207
 
208
  st.caption("Softmax")
209
- softmax_bytes = 2*bs*h*n + 2*bs*h*n
210
- c1, c2 = st.columns([2, 3])
211
- softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
212
 
213
  st.caption("Layer norm/residual connection")
214
- ln_bytes = 2*bs*1*d
215
- ln_flop = 0
216
- ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
217
 
218
  st.header('MLP layer')
219
  st.subheader('First and Second Linear Layer')
220
  flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
221
- c1, c2 = st.columns([2, 3])
222
- mlp2_time = print_kernel_execution(c1, c2, flop, nbytes)
223
 
224
  st.subheader('Element-wise ops')
225
  st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
226
  flop, nbytes, exec_time = ln_exec(bs, h, n, d)
227
- c1, c2 = st.columns([2, 3])
228
- mlp2_time = print_kernel_execution(c1, c2, flop, nbytes)
 
31
  # in ms
32
  def calc_exec_time(comp_flop, mem_bytes, include_overhead=True):
33
  exec_time = comp_flop/TFLOPS + mem_bytes/GB_S
34
+ exec_time *= 1000
35
  if include_overhead:
36
+ exec_time = max(exec_time, THREAD_OVERHEAD)
37
  return exec_time
38
 
39
  def qkv_mha_exec(bs, h, n, d):
 
110
  c2.write(str(exec_time))
111
  c1.write("Overhead (ms):")
112
  c2.write(str(THREAD_OVERHEAD))
 
 
 
113
 
114
  st.header("Execution time (ms)")
115
 
 
183
  st.subheader('Attention-value gemm')
184
  st.write("Showing calculation for the maximum sequence length (n)")
185
  st.caption("Multi-Head Attention")
186
+ flop, nbytes, exec_time = att2_mha_exec(bs, h, n, d)
187
+ print_kernel_execution(flop, nbytes)
 
 
188
 
189
  st.caption("Multi-Query Attention")
190
+ flop, nbytes, exec_time = att2_mqa_exec(bs, h, n, d)
191
+ print_kernel_execution(flop, nbytes)
 
 
192
 
193
  st.subheader('Output projection')
194
+ flop, nbytes, exec_time = out_exec(bs, h, n, d)
195
+ print_kernel_execution(flop, nbytes)
 
 
196
 
197
  st.subheader('Element-wise ops')
198
  st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
199
 
200
  st.caption("Softmax")
201
+ flop, nbytes, exec_time = softmax_exec(bs, h, n, d)
202
+ print_kernel_execution(flop, nbytes)
 
203
 
204
  st.caption("Layer norm/residual connection")
205
+ flop, nbytes, exec_time = ln_exec(bs, h, n, d)
206
+ print_kernel_execution(flop, nbytes)
 
207
 
208
  st.header('MLP layer')
209
  st.subheader('First and Second Linear Layer')
210
  flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
211
+ print_kernel_execution(flop, nbytes)
 
212
 
213
  st.subheader('Element-wise ops')
214
  st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
215
  flop, nbytes, exec_time = ln_exec(bs, h, n, d)
216
+ print_kernel_execution(flop, nbytes)