harmdevries committed on
Commit
3ddfba1
·
1 Parent(s): 5e8eef3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -115,13 +115,24 @@ for i in range(n_start, n):
115
  mqa_time = shared_time + qkv_mqa_exec(bs, h, i, d)[2] + att1_mqa_exec(bs, h, i, d)[2] + att2_mqa_exec(bs, h, i, d)[2]
116
  mqa_total_time += l*mqa_time
117
 
118
- st.write("Multi-Head Attention: " + str(mha_total_time))
119
- st.write("Multi-Query Attention: " + str(mqa_total_time))
120
- st.write("Speed-up MQA over MHA: " + str(mha_total_time/mqa_total_time))
 
 
 
 
121
 
122
  st.header("Memory consumption")
 
 
123
  num_params = 12*l*d*d
124
- st.write("Num Parameters (in B)" + str(round(num_params/1e9, 3)))
 
 
 
 
 
125
 
126
 
127
  breakdown = st.checkbox("Show breakdown per operation")
 
115
  mqa_time = shared_time + qkv_mqa_exec(bs, h, i, d)[2] + att1_mqa_exec(bs, h, i, d)[2] + att2_mqa_exec(bs, h, i, d)[2]
116
  mqa_total_time += l*mqa_time
117
 
118
+ c1, c2 = st.columns([2, 4])
119
+ c1.write("Multi-Head Attention:")
120
+ c2.write(str(round(mha_total_time, 2)))
121
+ c1.write("Multi-Query Attention:")
122
+ c2.write(str(round(mqa_total_time, 2)))
123
+ c1.write("Speed-up MQA over MHA: ")
124
+ c2.write(str(round(mha_total_time/mqa_total_time, 2)))
125
 
126
  st.header("Memory consumption")
127
+ st.caption("MHA")
128
+ c1, c2 = st.columns([2, 4])
129
  num_params = 12*l*d*d
130
+ c1.write("Num Parameters (in B)")
131
+ c2.write(str(round(num_params/1e9, 3)))
132
+ c1.write("Storing activations")
133
+ acts = round(2*l*(d/h)*h*n/1e9, 2)
134
+ c2.write(str(acts))
135
+
136
 
137
 
138
  breakdown = st.checkbox("Show breakdown per operation")