Spaces:
Runtime error
Runtime error
Commit
·
a275f69
1
Parent(s):
ea57214
Update app.py
Browse files
app.py
CHANGED
@@ -35,51 +35,54 @@ h = number_field('Num heads', value=16)
|
|
35 |
d = number_field('Dimension', value=768)
|
36 |
n_start = number_field('Start seq', value=1)
|
37 |
n = number_field('End seq', value=1024)
|
38 |
-
l = number_field('Num layers', value=
|
39 |
|
40 |
-
st.header('
|
41 |
|
42 |
-
st.subheader(
|
|
|
43 |
mha_flop = 2*bs*1*d*3*d
|
44 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
45 |
c1, c2 = st.columns([2, 3])
|
46 |
qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
47 |
|
48 |
-
st.
|
49 |
mqa_flop = 2*bs*1*d*(1+2/h)*d
|
50 |
# Bytes moved = input activations + projection weights + output activations.
# The weight and output terms use the same (1+2/h) factor as mqa_flop, so the
# Q projection is counted alongside the two reduced K/V projections.
mqa_bytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
|
51 |
c1, c2 = st.columns([2, 3])
|
52 |
# Keep the MQA timing under its own name so the MHA result computed above
# (qkv_mha_time) is not overwritten; matches att1_mqa_time / att_mqa_time.
qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
53 |
|
54 |
-
st.
|
55 |
-
st.write("
|
56 |
|
57 |
-
st.
|
58 |
mha_flop = 2*bs*h*(d/h)*n
|
59 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
60 |
c1, c2 = st.columns([2, 3])
|
61 |
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
62 |
|
63 |
-
st.
|
64 |
mqa_flop = 2*bs*h*(d/h)*n
|
65 |
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
66 |
c1, c2 = st.columns([2, 3])
|
67 |
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
68 |
|
69 |
-
st.header('Attention
|
70 |
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
71 |
-
st.
|
72 |
mha_flop = 2*bs*h*n*(d/h)
|
73 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
74 |
c1, c2 = st.columns([2, 3])
|
75 |
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
76 |
|
77 |
-
st.
|
78 |
mqa_flop = 2*bs*h*n*(d/h)
|
79 |
# Operands mirror the MHA line: attention scores for all h heads (bs*h*n),
# the value tensor, and the per-head output (bs*h*d/h). Only the value term
# drops the factor h, as the key term does in the QK gemm MQA formula.
mqa_bytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
80 |
c1, c2 = st.columns([2, 3])
|
81 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
82 |
|
|
|
|
|
83 |
st.header('MLP')
|
84 |
st.subheader('First Linear')
|
85 |
mlp1_flop = 2*bs*1*d*4*d
|
|
|
35 |
d = number_field('Dimension', value=768)
|
36 |
n_start = number_field('Start seq', value=1)
|
37 |
n = number_field('End seq', value=1024)
|
38 |
+
l = number_field('Num layers', value=24)
|
39 |
|
40 |
+
st.header('Attention layer')
|
41 |
|
42 |
+
st.subheader('QKV projection')
|
43 |
+
st.caption("Multi-Head Attention")
|
44 |
mha_flop = 2*bs*1*d*3*d
|
45 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
46 |
c1, c2 = st.columns([2, 3])
|
47 |
qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
48 |
|
49 |
+
st.caption("Multi-Query Attention")
|
50 |
mqa_flop = 2*bs*1*d*(1+2/h)*d
|
51 |
# Bytes moved = input activations + projection weights + output activations.
# The weight and output terms use the same (1+2/h) factor as mqa_flop, so the
# Q projection is counted alongside the two reduced K/V projections.
mqa_bytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
|
52 |
c1, c2 = st.columns([2, 3])
|
53 |
# Keep the MQA timing under its own name so the MHA result computed above
# (qkv_mha_time) is not overwritten; matches att1_mqa_time / att_mqa_time.
qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
54 |
|
55 |
+
st.subheader('QK gemm')
|
56 |
+
st.write("Note that calculation depends on sequence length (n)")
|
57 |
|
58 |
+
st.caption("Multi-Head Attention")
|
59 |
mha_flop = 2*bs*h*(d/h)*n
|
60 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
61 |
c1, c2 = st.columns([2, 3])
|
62 |
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
63 |
|
64 |
+
st.caption("Multi-Query Attention")
|
65 |
mqa_flop = 2*bs*h*(d/h)*n
|
66 |
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
67 |
c1, c2 = st.columns([2, 3])
|
68 |
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
69 |
|
70 |
+
# Rendered as a subheader so it sits at the same level as 'QK gemm' and
# 'Output projection' under the 'Attention layer' header.
st.subheader('Attention-value gemm')
|
71 |
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
72 |
+
st.caption("Multi-Head Attention")
|
73 |
mha_flop = 2*bs*h*n*(d/h)
|
74 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
75 |
c1, c2 = st.columns([2, 3])
|
76 |
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
77 |
|
78 |
+
st.caption("Multi-Query Attention")
|
79 |
mqa_flop = 2*bs*h*n*(d/h)
|
80 |
# Operands mirror the MHA line: attention scores for all h heads (bs*h*n),
# the value tensor, and the per-head output (bs*h*d/h). Only the value term
# drops the factor h, as the key term does in the QK gemm MQA formula.
mqa_bytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
81 |
c1, c2 = st.columns([2, 3])
|
82 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
83 |
|
84 |
+
st.subheader('Output projection')
|
85 |
+
|
86 |
st.header('MLP')
|
87 |
st.subheader('First Linear')
|
88 |
mlp1_flop = 2*bs*1*d*4*d
|