harmdevries commited on
Commit
d9d6497
·
1 Parent(s): 2e71f02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -29
app.py CHANGED
@@ -1,32 +1,16 @@
1
  import streamlit as st
2
 
3
- def number_field(label, columns=None, **input_params):
4
- c1, c2 = st.columns(columns or [1, 4])
5
-
6
- # Display field name with some alignment
7
- c1.markdown("##")
8
- c1.markdown(label)
9
-
10
- # Sets a default key parameter to avoid duplicate key errors
11
- input_params.setdefault("key", label)
12
-
13
- # Forward text input parameters
14
- return c2.number_input("", **input_params)
15
-
16
- def key_value(key, value, columns=None):
17
- c1, c2 = st.columns(columns or [2, 3])
18
-
19
- # Display field name with some alignment
20
- c1.markdown("##")
21
- c1.markdown(key)
22
- c2.markdown("##")
23
- c2.markdown(value)
24
-
25
  st.header("Transformer parameters")
26
- bs = number_field('Batch size: ', value=10)
27
- h = number_field('Num heads: ', value=16)
28
- d = number_field('Dimension: ', value=768)
29
- n = number_field('Seq length: ', value=1024)
 
 
 
 
 
 
30
 
31
  st.header('Query, Key, Value projection')
32
 
@@ -46,8 +30,12 @@ mqa_flop = 2*bs*n*d*(1+2/h)*d
46
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
47
 
48
  st.subheader("Multi-query Attention")
49
- key_value("FLOP: ", str(mqa_flop))
50
- key_value("bytes: ", str(mqa_bytes))
51
- key_value("Arithm. intensity:", str(mqa_flop/mqa_bytes))
 
 
 
 
52
 
53
  st.header('Attention')
 
1
  import streamlit as st
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  st.header("Transformer parameters")
4
+ col1, col2 = st.columns([2, 4])
5
+
6
+ col1.write('Batch size: ')
7
+ bs = col2.number_input(value=10)
8
+ col1.write('Num heads:')
9
+ h = col2.number_input(value=16)
10
+ col1.write('Dimension:')
11
+ d = col2.number_input(value=768)
12
+ col1.write('Seq length:')
13
+ n = col2.number_input(value=1024)
14
 
15
  st.header('Query, Key, Value projection')
16
 
 
30
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
31
 
32
  st.subheader("Multi-query Attention")
33
+ c1, c2 = st.columns([2, 3])
34
+ c1.write("FLOP:")
35
+ c2.write(str(mqa_flop))
36
+ c1.write("Bytes: ")
37
+ c2.write(str(mqa_bytes))
38
+ c1.write("Arithm. intensity:")
39
+ c2.write(str(mqa_flop/mqa_bytes))
40
 
41
  st.header('Attention')