harmdevries committed on
Commit a6d7fbc · 1 Parent(s): b31a1d5

Update app.py

Files changed (1)
  1. app.py +88 -89
app.py CHANGED
@@ -117,93 +117,92 @@ for i in range(n_start, n):
 
 st.write("Multi-Head Attention: " + str(mha_total_time))
 st.write("Multi-Query Attention: " + str(mqa_total_time))
+ st.write("Speed-up MQA over MHA: " + str(mha_total_time/mqa_total_time))
 
- st.header('Attention layer')
-
- st.subheader('QKV projection')
- st.caption("Multi-Head Attention")
- mha_flop = 2*bs*1*d*3*d
- mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
- c1, c2 = st.columns([2, 3])
- qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
-
- st.caption("Multi-Query Attention")
- mqa_flop = 2*bs*1*d*(1+2/h)*d
- mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
- c1, c2 = st.columns([2, 3])
- qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
-
- st.subheader('QK gemm')
- st.write("Note that calculation depends on sequence length (n)")
-
- st.caption("Multi-Head Attention")
- mha_flop = 2*bs*h*(d/h)*n
- mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
- c1, c2 = st.columns([2, 3])
- att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
-
- st.caption("Multi-Query Attention")
- mqa_flop = 2*bs*h*(d/h)*n
- mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
- c1, c2 = st.columns([2, 3])
- att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
-
- st.subheader('Attention-value gemm')
- st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
- st.caption("Multi-Head Attention")
- mha_flop = 2*bs*h*n*(d/h)
- mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
- c1, c2 = st.columns([2, 3])
- att2_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
-
- st.caption("Multi-Query Attention")
- mqa_flop = 2*bs*h*n*(d/h)
- mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
- c1, c2 = st.columns([2, 3])
- att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
-
- st.subheader('Output projection')
- out_flop = 2*bs*1*d*d
- out_bytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
- c1, c2 = st.columns([2, 3])
- out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
-
- st.subheader('Element-wise ops')
- st.write("We also need to take into account the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound.")
-
- st.caption("Softmax")
- softmax_bytes = 2*bs*h*n + 2*bs*h*n
- c1, c2 = st.columns([2, 3])
- softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
-
- st.caption("Layer norm/residual connection")
- ln_bytes = 2*bs*1*d
- ln_flop = 0
- ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
-
- st.header('MLP')
- st.subheader('First Linear')
- mlp1_flop = 2*bs*1*d*4*d
- mlp1_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
- c1, c2 = st.columns([2, 3])
- mlp1_time = print_kernel_execution(c1, c2, mlp1_flop, mlp1_bytes)
-
- st.subheader('Second Linear')
- mlp2_flop = 2*bs*1*d*4*d
- mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
- c1, c2 = st.columns([2, 3])
- mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
-
- st.subheader('Element-wise ops')
- st.write("We also need to take into account the GeLU, layer norm, and residual connection. We assume that these operations are memory bound.")
- ln_bytes = 2*bs*1*d
- ln_flop = 0
- ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
-
- st.header("Adding it all up")
-
- shared_time = out_time + softmax_time + 2*ln_time + mlp1_time + mlp2_time + 3*ln_time
- mha_total_time = qkv_mha_time + att1_mha_time + att2_mha_time + shared_time
- mqa_total_time = qkv_mqa_time + att1_mqa_time + att2_mqa_time + shared_time
- st.write("MHA exec time (ms): " + str(mha_total_time))
- st.write("MQA exec time (ms): " + str(mqa_total_time))
+ st.header("Memory consumption")
+
+
+
+ breakdown = st.checkbox("Show breakdown per layer")
+ if breakdown:
+     st.header('Attention layer')
+
+     st.subheader('QKV projection')
+     st.caption("Multi-Head Attention")
+     mha_flop = 2*bs*1*d*3*d
+     mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
+     c1, c2 = st.columns([2, 3])
+     qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
+
+     st.caption("Multi-Query Attention")
+     mqa_flop = 2*bs*1*d*(1+2/h)*d
+     mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
+     c1, c2 = st.columns([2, 3])
+     qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
+
+     st.subheader('QK gemm')
+     st.write("Note that calculation depends on sequence length (n)")
+
+     st.caption("Multi-Head Attention")
+     mha_flop = 2*bs*h*(d/h)*n
+     mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
+     c1, c2 = st.columns([2, 3])
+     att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
+
+     st.caption("Multi-Query Attention")
+     mqa_flop = 2*bs*h*(d/h)*n
+     mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
+     c1, c2 = st.columns([2, 3])
+     att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
+
+     st.subheader('Attention-value gemm')
+     st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
+     st.caption("Multi-Head Attention")
+     mha_flop = 2*bs*h*n*(d/h)
+     mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
+     c1, c2 = st.columns([2, 3])
+     att2_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
+
+     st.caption("Multi-Query Attention")
+     mqa_flop = 2*bs*h*n*(d/h)
+     mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
+     c1, c2 = st.columns([2, 3])
+     att2_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
+
+     st.subheader('Output projection')
+     out_flop = 2*bs*1*d*d
+     out_bytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
+     c1, c2 = st.columns([2, 3])
+     out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
+
+     st.subheader('Element-wise ops')
+     st.write("We also need to take into account the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound.")
+
+     st.caption("Softmax")
+     softmax_bytes = 2*bs*h*n + 2*bs*h*n
+     c1, c2 = st.columns([2, 3])
+     softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
+
+     st.caption("Layer norm/residual connection")
+     ln_bytes = 2*bs*1*d
+     ln_flop = 0
+     ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
+
+     st.header('MLP')
+     st.subheader('First Linear')
+     mlp1_flop = 2*bs*1*d*4*d
+     mlp1_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
+     c1, c2 = st.columns([2, 3])
+     mlp1_time = print_kernel_execution(c1, c2, mlp1_flop, mlp1_bytes)
+
+     st.subheader('Second Linear')
+     mlp2_flop = 2*bs*1*d*4*d
+     mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
+     c1, c2 = st.columns([2, 3])
+     mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
+
+     st.subheader('Element-wise ops')
+     st.write("We also need to take into account the GeLU, layer norm, and residual connection. We assume that these operations are memory bound.")
+     ln_bytes = 2*bs*1*d
+     ln_flop = 0
+     ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
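
Every block in this hunk funnels its FLOP and byte counts through print_kernel_execution, which is defined earlier in app.py and not shown here. A minimal sketch of what such a helper plausibly computes, assuming a simple roofline model in which a kernel is bound by the slower of peak compute and memory bandwidth (the hardware constants below are illustrative assumptions, not values from this commit):

import streamlit as st

# Illustrative hardware assumptions (A100-like); not taken from this commit.
THROUGHPUT = 312e12  # peak FLOP/s
BANDWIDTH = 1.5e12   # peak bytes/s

def print_kernel_execution(c1, c2, flop, nbytes):
    # Roofline estimate: a kernel needs at least flop/THROUGHPUT seconds of
    # compute and nbytes/BANDWIDTH seconds of memory traffic, so the larger
    # of the two bounds its execution time (reported in milliseconds).
    compute_ms = flop / THROUGHPUT * 1000
    memory_ms = nbytes / BANDWIDTH * 1000
    c1.write("FLOP: " + str(flop))
    c1.write("Bytes: " + str(nbytes))
    c2.write("Compute time (ms): " + str(compute_ms))
    c2.write("Memory time (ms): " + str(memory_ms))
    return max(compute_ms, memory_ms)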
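
The (1+2/h) factor in the MQA QKV projection follows from Q keeping its full d-dimensional output while K and V each shrink to a single head of size d/h. A quick check with illustrative sizes (the values of bs, h, d below are assumptions, not the app's inputs):

bs, h, d = 32, 48, 6144          # illustrative batch size, heads, model dim
mha_flop = 2*bs*1*d*3*d          # Q, K, V each project d -> d
mqa_flop = 2*bs*1*d*(1+2/h)*d    # Q: d -> d; K and V: d -> d/h each
print(mha_flop / mqa_flop)       # ~2.88, approaching 3x as h grows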
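
The QK gemm is where MQA pays off: the byte formulas above differ only in the K-cache term, 2*bs*h*n*(d/h) for MHA (a separate cache per head) versus 2*bs*n*(d/h) for MQA (one head shared by all queries), and that term dominates at long sequence lengths. With the same illustrative sizes plus a sequence length:

bs, h, d, n = 32, 48, 6144, 2048                      # illustrative sizes
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n  # Q + per-head K cache + scores
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n    # Q + shared K cache + scores
print(mha_bytes / mqa_bytes)                          # ~35x less memory traffic here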