harmdevries commited on
Commit
409563e
·
1 Parent(s): 07abc51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -95,7 +95,7 @@ def mlp_exec(bs, h, n, d):
95
  exec_time = calc_exec_time(flop, nbytes)
96
  return flop, nbytes, exec_time
97
 
98
- def print_kernel_execution(flop, mem_bytes):
99
  c1, c2 = st.columns([2, 3])
100
  exec_time = calc_exec_time(flop, nbytes, include_overhead=False)
101
  flop = round(flop/1e9, 2)
@@ -175,16 +175,12 @@ if breakdown:
175
  st.write("Showing calculation for the maximum sequence length (n)")
176
 
177
  st.caption("Multi-Head Attention")
178
- mha_flop = 2*bs*h*(d/h)*n
179
- mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
180
- c1, c2 = st.columns([2, 3])
181
- att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
182
 
183
  st.caption("Multi-Query Attention")
184
- mqa_flop = 2*bs*h*(d/h)*n
185
- mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
186
- c1, c2 = st.columns([2, 3])
187
- att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
188
 
189
  st.subheader('Attention-value gemm')
190
  st.write("Showing calculation for the maximum sequence length (n)")
 
95
  exec_time = calc_exec_time(flop, nbytes)
96
  return flop, nbytes, exec_time
97
 
98
+ def print_kernel_execution(flop, nbytes):
99
  c1, c2 = st.columns([2, 3])
100
  exec_time = calc_exec_time(flop, nbytes, include_overhead=False)
101
  flop = round(flop/1e9, 2)
 
175
  st.write("Showing calculation for the maximum sequence length (n)")
176
 
177
  st.caption("Multi-Head Attention")
178
+ flop, nbytes, exec_time = att1_mha_exec(bs, h, n, d)
179
+ print_kernel_execution(flop, nbytes)
 
 
180
 
181
  st.caption("Multi-Query Attention")
182
+ flop, nbytes, exec_time = att1_mqa_exec(bs, h, n, d)
183
+ print_kernel_execution(flop, nbytes)
 
 
184
 
185
  st.subheader('Attention-value gemm')
186
  st.write("Showing calculation for the maximum sequence length (n)")