harmdevries commited on
Commit
2a3864d
·
1 Parent(s): 729a063

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -53,7 +53,7 @@ st.caption("Multi-Query Attention")
53
  mqa_flop = 2*bs*1*d*(1+2/h)*d
54
  mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
55
  c1, c2 = st.columns([2, 3])
56
- qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
57
 
58
  st.subheader('QK gemm')
59
  st.write("Note that calculation depends on sequence length (n)")
@@ -101,7 +101,7 @@ softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
101
  st.caption("Layer norm/residual connection")
102
  ln_bytes = 2*bs*1*d
103
  ln_flop = 0
104
- ln_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
105
 
106
  st.header('MLP')
107
  st.subheader('First Linear')
@@ -120,7 +120,7 @@ st.subheader('Element-wise ops')
120
  st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
121
  ln_bytes = 2*bs*1*d
122
  ln_flop = 0
123
- ln_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
124
 
125
  st.header("Adding it all up")
126
 
 
53
  mqa_flop = 2*bs*1*d*(1+2/h)*d
54
  mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
55
  c1, c2 = st.columns([2, 3])
56
+ qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
57
 
58
  st.subheader('QK gemm')
59
  st.write("Note that calculation depends on sequence length (n)")
 
101
  st.caption("Layer norm/residual connection")
102
  ln_bytes = 2*bs*1*d
103
  ln_flop = 0
104
+ ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
105
 
106
  st.header('MLP')
107
  st.subheader('First Linear')
 
120
  st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
121
  ln_bytes = 2*bs*1*d
122
  ln_flop = 0
123
+ ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
124
 
125
  st.header("Adding it all up")
126