Commit 52a11ab
1 Parent(s): 82370ff
Update app.py
app.py CHANGED
@@ -134,7 +134,7 @@ c1.write("Multi-Head Attention:")
 c2.write(str(round(mha_total_time, 2)))
 c1.write("Multi-Query Attention:")
 c2.write(str(round(mqa_total_time, 2)))
-c1.write("Speed-up MQA over MHA:
+c1.write("Speed-up MQA over MHA:")
 c2.write(str(round(mha_total_time/mqa_total_time,2)))
 
 st.subheader("Memory consumption")
@@ -161,7 +161,7 @@ acts = round(2*bs*l*(d/h)*2*n/1e9, 2)
 c2.write(str(acts))
 
 st.subheader("Approximations")
-st.markdown("
+st.markdown("We use the [following crude approximation](https://docs.nvidia.com/deeplearning/performance/dl-performance-gpu-background/index.html#understand-perf) to estimate the execution time for each matrix multiplication.")
 
 st.latex("C = A \cdot B")
 st.latex("A \in \mathbb{R}^{MxK}, B \in R^{KxN}, C \in \mathbb{R}^{MxN}")
@@ -173,11 +173,6 @@ To execute this operation on the GPU, we need to
 3. Write C to memory
 ''')
 
-st.latex('''
-For float16 operations (2 bytes), we can estimate the memory access time of A as follows:
-T_mem(A) = 2*M*K / BW_mem
-where BW_mem is the memory bandwidth of the GPU (e.g. 1935 GB/s for A100)
-''')
 
 st.latex('''
 For float16 operations (2 bytes), we can estimate the memory access time of A as follows:
@@ -185,6 +180,9 @@ T_mem(A) = 2*M*K / BW_mem
 where BW_mem is the memory bandwidth of the GPU (e.g. 1935 GB/s for A100)
 ''')
 
+st.markdown("For float16 operations (2 bytes), we can estimate the memory access time of A as follows:")
+st.latex("T_{mem}(A) = 2*M*K / BW_{mem}")
+st.markdown("where BW_mem is the memory bandwidth of the GPU (e.g. 1935 GB/s for A100)")
 
 
 
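For anyone reviewing the formula this commit moves into `st.latex`, the estimate T_mem(A) = 2*M*K / BW_mem can be checked numerically with a short sketch. This is not code from app.py: the helper name `mem_access_time` and the 4096 x 4096 shape are hypothetical, chosen only for illustration. The bandwidth is the 1935 GB/s A100 figure quoted in the diff, and 1 GB is assumed to mean 1e9 bytes.

```python
def mem_access_time(rows: int, cols: int, bw_bytes_per_s: float,
                    bytes_per_elem: int = 2) -> float:
    """Estimate the seconds needed to read or write a rows x cols matrix once.

    bytes_per_elem defaults to 2, matching the float16 case in the diff.
    """
    return bytes_per_elem * rows * cols / bw_bytes_per_s


# 1935 GB/s, the A100 figure quoted in the diff (assuming 1 GB = 1e9 bytes).
A100_BW = 1935e9

# Hypothetical shape for A: M = 4096, K = 4096.
M, K = 4096, 4096
t_a = mem_access_time(M, K, A100_BW)
print(f"T_mem(A) = {t_a * 1e6:.2f} microseconds")
```

With these assumed values the estimate is on the order of 17 microseconds. The same helper would apply to B and C by passing their shapes (K x N and M x N).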