harmdevries committed on
Commit
c88286f
1 Parent(s): 064a5f0

Update app.py

Files changed (1)
  1. app.py +58 -59
app.py CHANGED
@@ -116,7 +116,7 @@ def print_kernel_execution(flop, nbytes):
  c2.write(str(THREAD_OVERHEAD))

  st.title("Inference time MHA vs MQA")
- st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention model. You can change the hyperparameters in sidebar.")
+ st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention transformers. You can change the hyperparameters in the sidebar.")

  mqa_total_time = 0.
  mha_total_time = 0.
@@ -187,63 +187,62 @@ st.latex("max(T_{math}, T_{mem})")

  st.markdown("We also assume a minimum time for executing the operation due to [kernel launch overhead](https://forums.developer.nvidia.com/t/any-way-to-measure-the-latency-of-a-kernel-launch/221413/2)")

- st.subheader("Operations in MHA and MQA")
-
- breakdown = st.checkbox("Show inference time for each operation")
- if breakdown:
-     st.header('Attention layer')
-
-     st.subheader('QKV projection')
-     st.caption("Multi-Head Attention")
-     flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.caption("Multi-Query Attention")
-     flop, nbytes, exec_time = qkv_mqa_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.subheader('QK gemm')
-     st.write("Showing calculation for the maximum sequence length (n)")
-
-     st.caption("Multi-Head Attention")
-     flop, nbytes, exec_time = att1_mha_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.caption("Multi-Query Attention")
-     flop, nbytes, exec_time = att1_mqa_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.subheader('Attention-value gemm')
-     st.write("Showing calculation for the maximum sequence length (n)")
-     st.caption("Multi-Head Attention")
-     flop, nbytes, exec_time = att2_mha_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.caption("Multi-Query Attention")
-     flop, nbytes, exec_time = att2_mqa_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.subheader('Output projection')
-     flop, nbytes, exec_time = out_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.subheader('Element-wise ops')
-     st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
-
-     st.caption("Softmax")
-     flop, nbytes, exec_time = softmax_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.caption("Layer norm/residual connection")
-     flop, nbytes, exec_time = ln_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.header('MLP layer')
-     st.subheader('First and Second Linear Layer')
-     flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
-
-     st.subheader('Element-wise ops')
-     st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
-     flop, nbytes, exec_time = ln_exec(bs, h, n, d)
-     print_kernel_execution(flop, nbytes)
+ st.subheader("Inference time for Transformer operations")
+ st.text("We can now estimate the execution time for each of the operations in the transformer model. I suggest you inspect the code for details on the calculations.")
+
+ st.subheader('Attention layer')
+
+ st.markdown('**QKV projection**')
+ st.caption("Multi-Head Attention")
+ flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.caption("Multi-Query Attention")
+ flop, nbytes, exec_time = qkv_mqa_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.markdown('**QK gemm**')
+ st.write("Showing calculation for the maximum sequence length (n)")
+
+ st.caption("Multi-Head Attention")
+ flop, nbytes, exec_time = att1_mha_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.caption("Multi-Query Attention")
+ flop, nbytes, exec_time = att1_mqa_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.markdown('**Attention-value gemm**')
+ st.write("Showing calculation for the maximum sequence length (n)")
+ st.caption("Multi-Head Attention")
+ flop, nbytes, exec_time = att2_mha_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.caption("Multi-Query Attention")
+ flop, nbytes, exec_time = att2_mqa_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.markdown('**Output projection**')
+ flop, nbytes, exec_time = out_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.markdown('**Element-wise ops**')
+ st.write("We also need to take into account the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound.")
+
+ st.caption("Softmax")
+ flop, nbytes, exec_time = softmax_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.caption("Layer norm/residual connection")
+ flop, nbytes, exec_time = ln_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.subheader('MLP layer')
+ st.markdown('**First and Second Linear Layer**')
+ flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
+
+ st.markdown('**Element-wise ops**')
+ st.write("We also need to take into account the GeLU, layer norm, and residual connection. We assume that these operations are memory bound.")
+ flop, nbytes, exec_time = ln_exec(bs, h, n, d)
+ print_kernel_execution(flop, nbytes)
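The helper functions referenced above (`print_kernel_execution` and the `*_exec` calls) are defined elsewhere in app.py and are not part of this hunk. For readers skimming the diff, here is a minimal sketch of the roofline-style estimate the app describes, where each kernel costs max(T_math, T_mem) but never less than the kernel launch overhead. The hardware constants are illustrative A100-class assumptions, not values read from app.py.

```python
# Sketch of a roofline-style per-kernel time estimate (an assumption,
# not the app's actual implementation). Constants are rough A100 numbers.
GPU_FLOPS = 312e12             # assumed peak FP16 throughput, FLOP/s
GPU_MEM_BANDWIDTH = 1.5e12     # assumed HBM bandwidth, bytes/s
KERNEL_LAUNCH_OVERHEAD = 5e-6  # assumed ~5 microseconds per kernel launch

def kernel_time(flop: float, nbytes: float) -> float:
    """Estimated kernel execution time: compute time vs memory time,
    floored at the kernel launch overhead."""
    t_math = flop / GPU_FLOPS
    t_mem = nbytes / GPU_MEM_BANDWIDTH
    return max(t_math, t_mem, KERNEL_LAUNCH_OVERHEAD)
```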
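The FLOP and byte counts behind `qkv_mha_exec`/`qkv_mqa_exec` and the `att1_*`/`att2_*` gemms can be approximated for single-token (incremental) decoding as below. The parameter names (`batch`, `d_model`, `n_heads`, `seq_len`) are my own and may not map one-to-one onto the app's `bs`, `h`, `n`, `d`; the counts assume fp16 weights and activations and a full KV-cache read per decoded token.

```python
DTYPE_BYTES = 2  # assume fp16/bf16

def qkv_proj_counts(batch: int, d_model: int, n_heads: int, mqa: bool):
    """Approximate FLOPs and bytes to project one new token into Q, K, V.
    Under MQA, K and V each have a single head of width d_head."""
    d_head = d_model // n_heads
    kv_width = d_head if mqa else d_model  # output width of the K and V projections
    flop = 2 * batch * d_model * (d_model + 2 * kv_width)
    weight_bytes = d_model * (d_model + 2 * kv_width) * DTYPE_BYTES
    act_bytes = batch * (2 * d_model + 2 * kv_width) * DTYPE_BYTES
    return flop, weight_bytes + act_bytes

def attn_gemm_counts(batch: int, d_model: int, n_heads: int, seq_len: int, mqa: bool):
    """Approximate FLOPs and bytes for QK^T plus attention @ V for one
    decoded token against a KV cache of length seq_len."""
    d_head = d_model // n_heads
    kv_heads = 1 if mqa else n_heads
    flop = 2 * 2 * batch * n_heads * seq_len * d_head  # QK^T and attn @ V
    kv_cache_bytes = 2 * batch * kv_heads * seq_len * d_head * DTYPE_BYTES
    score_bytes = 2 * batch * n_heads * seq_len * DTYPE_BYTES
    return flop, kv_cache_bytes + score_bytes
```

Written this way, the comparison the space visualizes is easy to see: the gemm FLOPs are identical for MHA and MQA, but the KV-cache traffic drops by roughly a factor of `n_heads` under MQA, which is what pulls the attention gemms away from the memory-bound regime.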
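Finally, rough stand-ins for `mlp_exec` and the memory-bound element-wise estimates (softmax, layer norm/residual, GeLU) could look like the following. The 4x feed-forward expansion and the "count only bytes moved" treatment are assumptions consistent with the text shown in the diff, not code copied from app.py.

```python
DTYPE_BYTES = 2  # assume fp16/bf16

def mlp_counts(batch: int, d_model: int, ffn_mult: int = 4):
    """Approximate FLOPs and bytes for the two MLP linears on one token,
    assuming the common d_model -> ffn_mult*d_model -> d_model shape."""
    d_ff = ffn_mult * d_model
    flop = 2 * batch * (d_model * d_ff + d_ff * d_model)
    weight_bytes = 2 * d_model * d_ff * DTYPE_BYTES
    act_bytes = batch * (2 * d_model + 2 * d_ff) * DTYPE_BYTES
    return flop, weight_bytes + act_bytes

def elementwise_bytes(numel: int) -> int:
    """Memory traffic for a memory-bound element-wise op: one read and one
    write of the tensor (e.g. numel = batch * d_model for layer norm/residual)."""
    return 2 * numel * DTYPE_BYTES
```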