Spaces:
Runtime error
Runtime error
Commit
·
86cd028
1
Parent(s):
f0d14cb
Update app.py
Browse files
app.py
CHANGED
@@ -6,9 +6,12 @@ def number_field(label, **kwargs):
|
|
6 |
|
7 |
return c2.number_input('', **kwargs)
|
8 |
|
|
|
|
|
|
|
9 |
def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
|
10 |
arith_int = comp_flop/mem_bytes
|
11 |
-
exec_time = (comp_flop
|
12 |
|
13 |
comp_flop = round(comp_flop/1e9, 2)
|
14 |
mem_bytes = round(mem_bytes/1e6, 2)
|
@@ -82,13 +85,21 @@ c1, c2 = st.columns([2, 3])
|
|
82 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
83 |
|
84 |
st.subheader('Output projection')
|
85 |
-
|
86 |
-
|
87 |
c1, c2 = st.columns([2, 3])
|
88 |
-
|
89 |
|
90 |
st.subheader('Element-wise ops')
|
91 |
-
st.write("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
st.header('MLP')
|
94 |
st.subheader('First Linear')
|
|
|
6 |
|
7 |
return c2.number_input('', **kwargs)
|
8 |
|
9 |
+
def calc_exec_time(comp_flop, mem_bytes):
|
10 |
+
return (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
|
11 |
+
|
12 |
def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
|
13 |
arith_int = comp_flop/mem_bytes
|
14 |
+
exec_time = calc_exec_time(comp_flop, mem_bytes)
|
15 |
|
16 |
comp_flop = round(comp_flop/1e9, 2)
|
17 |
mem_bytes = round(mem_bytes/1e6, 2)
|
|
|
85 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
86 |
|
87 |
st.subheader('Output projection')
|
88 |
+
out_flop = 2*bs*1*d*d
|
89 |
+
out_bytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
|
90 |
c1, c2 = st.columns([2, 3])
|
91 |
+
out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
|
92 |
|
93 |
st.subheader('Element-wise ops')
|
94 |
+
st.write("We also need to take into the softmax layer and layer norm")
|
95 |
+
|
96 |
+
st.caption("Softmax")
|
97 |
+
softmax_bytes = 2*bs*h*n + 2*bs*h*n
|
98 |
+
c1, c2 = st.columns([2, 3])
|
99 |
+
softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes))
|
100 |
+
|
101 |
+
st.caption("Layer norm")
|
102 |
+
|
103 |
|
104 |
st.header('MLP')
|
105 |
st.subheader('First Linear')
|