Spaces:
Runtime error
Runtime error
Commit
·
8cbefab
1
Parent(s):
54cd0e6
Update app.py
Browse files
app.py
CHANGED
@@ -33,12 +33,12 @@ col1, col2 = st.columns([2, 4])
|
|
33 |
bs = number_field('Batch size', value=10)
|
34 |
h = number_field('Num heads', value=16)
|
35 |
d = number_field('Dimension', value=768)
|
36 |
-
|
|
|
|
|
37 |
|
38 |
st.header('Query, Key, Value projection')
|
39 |
|
40 |
-
|
41 |
-
|
42 |
st.subheader("Multi-Head Attention")
|
43 |
mha_flop = 2*bs*1*d*3*d
|
44 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
@@ -51,16 +51,28 @@ mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
|
|
51 |
c1, c2 = st.columns([2, 3])
|
52 |
qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
53 |
|
54 |
-
st.header('Attention')
|
|
|
55 |
|
56 |
st.subheader("Multi-Head Attention")
|
57 |
mha_flop = 2*bs*h*(d/h)*n
|
58 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
59 |
c1, c2 = st.columns([2, 3])
|
60 |
-
|
61 |
-
|
62 |
|
|
|
|
|
|
|
|
|
|
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
st.subheader("Multi-Query Attention")
|
66 |
mqa_flop = 2*bs*h*(d/h)*n
|
@@ -68,7 +80,6 @@ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
|
68 |
c1, c2 = st.columns([2, 3])
|
69 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
70 |
|
71 |
-
|
72 |
st.header('MLP')
|
73 |
st.subheader('First Linear')
|
74 |
mlp1_flop = 2*bs*1*d*4*d
|
|
|
33 |
bs = number_field('Batch size', value=10)
|
34 |
h = number_field('Num heads', value=16)
|
35 |
d = number_field('Dimension', value=768)
|
36 |
+
n_start = number_field('Start seq', value=1)
|
37 |
+
n = number_field('End seq', value=1024)
|
38 |
+
l = number_field('Num layers', value=2)
|
39 |
|
40 |
st.header('Query, Key, Value projection')
|
41 |
|
|
|
|
|
42 |
st.subheader("Multi-Head Attention")
|
43 |
mha_flop = 2*bs*1*d*3*d
|
44 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
|
|
51 |
c1, c2 = st.columns([2, 3])
|
52 |
qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
53 |
|
54 |
+
st.header('Attention scores: query-key gemm')
|
55 |
+
st.write("Calculation depends on sequence length (n). Take end of sequence.")
|
56 |
|
57 |
st.subheader("Multi-Head Attention")
|
58 |
mha_flop = 2*bs*h*(d/h)*n
|
59 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
60 |
c1, c2 = st.columns([2, 3])
|
61 |
+
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
|
|
62 |
|
63 |
+
st.subheader("Multi-Query Attention")
|
64 |
+
mqa_flop = 2*bs*h*(d/h)*n
|
65 |
+
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
66 |
+
c1, c2 = st.columns([2, 3])
|
67 |
+
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
68 |
|
69 |
+
st.header('Attention scores: ')
|
70 |
+
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
71 |
+
st.subheader("Multi-Head Attention")
|
72 |
+
mha_flop = 2*bs*h*(d/h)*n
|
73 |
+
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
74 |
+
c1, c2 = st.columns([2, 3])
|
75 |
+
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
76 |
|
77 |
st.subheader("Multi-Query Attention")
|
78 |
mqa_flop = 2*bs*h*(d/h)*n
|
|
|
80 |
c1, c2 = st.columns([2, 3])
|
81 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
82 |
|
|
|
83 |
st.header('MLP')
|
84 |
st.subheader('First Linear')
|
85 |
mlp1_flop = 2*bs*1*d*4*d
|