harmdevries committed
Commit 992208b · 1 Parent(s): a6d7fbc

Update app.py

Files changed (1)
  1. app.py +10 -17
app.py CHANGED
@@ -123,7 +123,7 @@ st.header("Memory consumption")
 
 
 
-breakdown = st.checkbox("Show breakdown per layer")
+breakdown = st.checkbox("Show breakdown per operation")
 if breakdown:
     st.header('Attention layer')
 
@@ -141,7 +141,7 @@ if breakdown:
     qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
 
     st.subheader('QK gemm')
-    st.write("Note that calculation depends on sequence length (n)")
+    st.write("Showing calculation for the maximum sequence length (n)")
 
     st.caption("Multi-Head Attention")
     mha_flop = 2*bs*h*(d/h)*n
@@ -156,7 +156,7 @@ if breakdown:
     att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
 
     st.subheader('Attention-value gemm')
-    st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
+    st.write("Showing calculation for the maximum sequence length (n)")
     st.caption("Multi-Head Attention")
     mha_flop = 2*bs*h*n*(d/h)
     mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
@@ -188,21 +188,14 @@ if breakdown:
     ln_flop = 0
     ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
 
-    st.header('MLP')
-    st.subheader('First Linear')
-    mlp1_flop = 2*bs*1*d*4*d
-    mlp1_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
+    st.header('MLP layer')
+    st.subheader('First and Second Linear Layer')
+    flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
     c1, c2 = st.columns([2, 3])
-    mlp1_time = print_kernel_execution(c1, c2, mlp1_flop, mlp1_bytes)
-
-    st.subheader('Second Linear')
-    mlp2_flop = 2*bs*1*d*4*d
-    mlp2_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
-    c1, c2 = st.columns([2, 3])
-    mlp2_time = print_kernel_execution(c1, c2, mlp2_flop, mlp2_bytes)
+    mlp2_time = print_kernel_execution(c1, c2, flop, nbytes)
 
     st.subheader('Element-wise ops')
     st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
-    ln_bytes = 2*bs*1*d
-    ln_flop = 0
-    ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
+    flop, nbytes, exec_time = ln_exec(bs, h, n, d)
+    c1, c2 = st.columns([2, 3])
+    mlp2_time = print_kernel_execution(c1, c2, flop, nbytes)
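Every entry in the breakdown funnels a FLOP count and a byte count into print_kernel_execution, whose body is outside this diff. Below is a minimal roofline-style sketch of the estimate such a helper plausibly makes; the kernel_time name and the hardware peaks are assumptions, not values from app.py.

# Roofline sketch: a kernel can go no faster than its compute or its memory
# traffic allows. The peaks below are assumed A100-class numbers (assumption),
# not constants taken from app.py.
PEAK_FLOPS = 312e12  # assumed fp16 tensor-core peak, FLOP/s
PEAK_BW = 1.5e12     # assumed HBM bandwidth, bytes/s

def kernel_time(flop, nbytes):
    compute_time = flop / PEAK_FLOPS   # time if purely compute bound
    memory_time = nbytes / PEAK_BW     # time if purely memory bound
    return max(compute_time, memory_time)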
 
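The commit also folds the inline math for the two MLP linears and the element-wise ops into mlp_exec and ln_exec helpers that this diff does not define. Here is a sketch reconstructed from the deleted inline formulas; the doubling of the two identical linear layers and the reuse of kernel_time from the sketch above are assumptions.

def mlp_exec(bs, h, n, d):
    # Both MLP linears (d -> 4*d and 4*d -> d) cost the same for a single
    # decoded token, so compute one and double it (assumption).
    flop = 2 * (2 * bs * 1 * d * 4 * d)
    # fp16 traffic per linear: input activations + weights + output activations.
    nbytes = 2 * (2 * bs * 1 * d + 2 * d * 4 * d + 2 * bs * 1 * 4 * d)
    return flop, nbytes, kernel_time(flop, nbytes)

def ln_exec(bs, h, n, d):
    # GeLU, layer norm, and residual are treated as memory bound:
    # zero effective FLOPs, bytes proportional to the activation size.
    flop = 0
    nbytes = 2 * bs * 1 * d
    return flop, nbytes, kernel_time(flop, nbytes)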
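To see why the app tracks bytes alongside FLOPs, here are the attention-value gemm formulas from the diff evaluated at illustrative sizes; bs, h, d, and n below are assumptions, not the app's defaults.

bs, h, d, n = 1, 16, 2048, 1024   # illustrative sizes (assumption)

# Formulas copied from the Multi-Head Attention branch of the diff.
mha_flop = 2 * bs * h * n * (d / h)                               # ~4.2 MFLOP
mha_bytes = 2 * bs * h * n + 2 * bs * h * n * (d / h) + 2 * bs * h * (d / h)
print(mha_flop / mha_bytes)   # ~1 FLOP per byte: heavily memory bound

At roughly one FLOP per byte, single-token decoding sits far below the compute-to-bandwidth ratio of modern accelerators, which is presumably why the app keeps the side-by-side multi-query comparison: MQA shrinks the key/value traffic that dominates these kernels.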