harmdevries commited on
Commit
b31a1d5
·
1 Parent(s): dbafc77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -7,8 +7,10 @@ GB_S = 1935e9
7
  # in ms
8
  THREAD_OVERHEAD = 0.005
9
 
 
10
  def calc_exec_time(comp_flop, mem_bytes):
11
- return (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
 
12
 
13
  def qkv_mha_exec(bs, h, n, d):
14
  flop = 2*bs*1*d*3*d
@@ -100,7 +102,7 @@ n = st.sidebar.number_input('End seq', value=1024)
100
  st.sidebar.header("GPU parameters")
101
 
102
 
103
- st.header("Total execution time")
104
 
105
  mqa_total_time = 0.
106
  mha_total_time = 0.
@@ -113,8 +115,8 @@ for i in range(n_start, n):
113
  mqa_time = shared_time + qkv_mqa_exec(bs, h, i, d)[2] + att1_mqa_exec(bs, h, i, d)[2] + att2_mqa_exec(bs, h, i, d)[2]
114
  mqa_total_time += l*mqa_time
115
 
116
- st.write("MHA exec time: " + str(mha_total_time))
117
- st.write("MQA exec time: " + str(mqa_total_time))
118
 
119
  st.header('Attention layer')
120
 
 
7
  # in ms
8
  THREAD_OVERHEAD = 0.005
9
 
10
+ # in ms
11
  def calc_exec_time(comp_flop, mem_bytes):
12
+ exec_time = comp_flop/TFLOPS + mem_bytes/GB_S
13
+ return max(exec_time*1000, THREAD_OVERHEAD)
14
 
15
  def qkv_mha_exec(bs, h, n, d):
16
  flop = 2*bs*1*d*3*d
 
102
  st.sidebar.header("GPU parameters")
103
 
104
 
105
+ st.header("Execution time (ms)")
106
 
107
  mqa_total_time = 0.
108
  mha_total_time = 0.
 
115
  mqa_time = shared_time + qkv_mqa_exec(bs, h, i, d)[2] + att1_mqa_exec(bs, h, i, d)[2] + att2_mqa_exec(bs, h, i, d)[2]
116
  mqa_total_time += l*mqa_time
117
 
118
+ st.write("Multi-Head Attention: " + str(mha_total_time))
119
+ st.write("Multi-Query Attention: " + str(mqa_total_time))
120
 
121
  st.header('Attention layer')
122