harmdevries committed on
Commit
4bb67b3
·
1 Parent(s): 06c81be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -28,6 +28,9 @@ else:
28
  # in ms
29
 
30
  THREAD_OVERHEAD = st.sidebar.number_input('Thread overhead (in ms)', format="%.3f", value=0.005)
 
 
 
31
 
32
  # in ms
33
  def calc_exec_time(comp_flop, mem_bytes, include_overhead=True):
@@ -112,7 +115,8 @@ def print_kernel_execution(flop, nbytes):
112
  c1.write("Overhead (ms):")
113
  c2.write(str(THREAD_OVERHEAD))
114
 
115
- st.header("Execution time (ms)")
 
116
 
117
  mqa_total_time = 0.
118
  mha_total_time = 0.
@@ -134,7 +138,7 @@ c1.write("Speed-up MQA over MHA: ")
134
  c2.write(str(round(mha_total_time/mqa_total_time,2)))
135
 
136
  st.header("Memory consumption")
137
- st.caption("MHA")
138
  c1, c2 = st.columns([2, 4])
139
  num_params = 12*l*d*d
140
  c1.write("Num Parameters (in B)")
@@ -145,7 +149,7 @@ c1.write("Cached keys and values (GB)")
145
  acts = round(2*bs*l*(d/h)*h*2*n/1e9, 2)
146
  c2.write(str(acts))
147
 
148
- st.caption("MQA")
149
  c1, c2 = st.columns([2, 4])
150
  num_params = (10+2/h)*l*d*d
151
  c1.write("Num Parameters (in B)")
@@ -156,6 +160,8 @@ c1.write("Cached keys and values (GB)")
156
  acts = round(2*bs*l*(d/h)*2*n/1e9, 2)
157
  c2.write(str(acts))
158
 
 
 
159
 
160
  breakdown = st.checkbox("Show breakdown per operation")
161
  if breakdown:
 
28
  # in ms
29
 
30
  THREAD_OVERHEAD = st.sidebar.number_input('Thread overhead (in ms)', format="%.3f", value=0.005)
31
+ GPU_EFFICIENCY = st.sidebar.number_input('GPU efficiency', format="%.3f", value=0.5)
32
+
33
+ TFLOPS = GPU_EFFICIENCY*TFLOPS
34
 
35
  # in ms
36
  def calc_exec_time(comp_flop, mem_bytes, include_overhead=True):
 
115
  c1.write("Overhead (ms):")
116
  c2.write(str(THREAD_OVERHEAD))
117
 
118
+ st.title("Inference time MHA vs MQA")
119
+ st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention model. You can change the hyperparameters in sidebar.")
120
 
121
  mqa_total_time = 0.
122
  mha_total_time = 0.
 
138
  c2.write(str(round(mha_total_time/mqa_total_time,2)))
139
 
140
  st.header("Memory consumption")
141
+ st.subheader("MHA")
142
  c1, c2 = st.columns([2, 4])
143
  num_params = 12*l*d*d
144
  c1.write("Num Parameters (in B)")
 
149
  acts = round(2*bs*l*(d/h)*h*2*n/1e9, 2)
150
  c2.write(str(acts))
151
 
152
+ st.subheader("MQA")
153
  c1, c2 = st.columns([2, 4])
154
  num_params = (10+2/h)*l*d*d
155
  c1.write("Num Parameters (in B)")
 
160
  acts = round(2*bs*l*(d/h)*2*n/1e9, 2)
161
  c2.write(str(acts))
162
 
163
+ st.subheader("How we approximate the inference time")
164
+ st.write("TODO")
165
 
166
  breakdown = st.checkbox("Show breakdown per operation")
167
  if breakdown: