Commit c88286f · 1 Parent(s): 064a5f0 · Update app.py
app.py CHANGED
@@ -116,7 +116,7 @@ def print_kernel_execution(flop, nbytes):
     c2.write(str(THREAD_OVERHEAD))
 
 st.title("Inference time MHA vs MQA")
-st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention
+st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention transformers. You can change the hyperparameters in sidebar.")
 
 mqa_total_time = 0.
 mha_total_time = 0.
@@ -187,63 +187,62 @@ st.latex("max(T_{math}, T_{mem})")
 
 st.markdown("We also a minimum time for executing the operation due to [kernel launch overhead](https://forums.developer.nvidia.com/t/any-way-to-measure-the-latency-of-a-kernel-launch/221413/2)")
 
-st.subheader("
+st.subheader("Inference time for Transformer operations")
+st.text("We can now estimate the execution for each of the operations in the transformer model. I suggest you inspect the code for details on the calculations. ")
 
-[old lines 192-197: removed content not captured in this view]
-flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
-print_kernel_execution(flop, nbytes)
+st.subheader('Attention layer')
+
+st.markdown('**QKV projection**')
+st.caption("Multi-Head Attention")
+flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
 
-[old lines 201-249: removed content not captured in this view]
+st.caption("Multi-Query Attention")
+flop, nbytes, exec_time = qkv_mqa_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**QK gemm**')
+st.write("Showing calculation for the maximum sequence length (n)")
+
+st.caption("Multi-Head Attention")
+flop, nbytes, exec_time = att1_mha_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.caption("Multi-Query Attention")
+flop, nbytes, exec_time = att1_mqa_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Attention-value gemm**')
+st.write("Showing calculation for the maximum sequence length (n)")
+st.caption("Multi-Head Attention")
+flop, nbytes, exec_time = att2_mha_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.caption("Multi-Query Attention")
+flop, nbytes, exec_time = att2_mqa_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Output projection**')
+flop, nbytes, exec_time = out_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Element-wise ops**')
+st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
+
+st.caption("Softmax")
+flop, nbytes, exec_time = softmax_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.caption("Layer norm/residual connection")
+flop, nbytes, exec_time = ln_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.subheader('MLP layer')
+st.markdown('**First and Second Linear Layer**')
+flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Element-wise ops**')
+st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
+flop, nbytes, exec_time = ln_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)