harmdevries committed on
Commit
b75a0d1
·
1 Parent(s): 23b5435

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -10
app.py CHANGED
@@ -1,10 +1,32 @@
1
  import streamlit as st
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  st.header("Transformer parameters")
4
- bs = st.number_input('Batch size: ', value=10)
5
- h = st.number_input('Num heads: ', value=16)
6
- d = st.number_input('Dimension: ', value=768)
7
- n = st.number_input('Seq length: ', value=1024)
8
 
9
  st.header('Query, Key, Value projection')
10
 
@@ -12,16 +34,16 @@ mha_flop = 2*bs*n*d*3*d
12
  mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d
13
 
14
  st.subheader("Multi-query Attention")
15
- st.text("FLOP: " + str(mha_flop))
16
- st.text("bytes: " + str(mha_bytes))
17
- st.text("Arithm. intensity:" + str(mha_flop/mha_bytes))
18
 
19
  mqa_flop = 2*bs*n*d*(1+2/h)*d
20
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
21
 
22
  st.subheader("Multi-query Attention")
23
- st.write("FLOP: " + str(mqa_flop))
24
- st.write("bytes: " + str(mqa_bytes))
25
- st.write("Arithm. intensity:" + str(mqa_flop/mqa_bytes))
26
 
27
  st.header('Attention')
 
1
  import streamlit as st
2
 
3
def number_field(label, columns=None, **input_params):
    """Render a numeric input with its label in a separate column beside it.

    Args:
        label: Markdown text shown in the first column. Also used as the
            widget ``key`` unless the caller passes one explicitly.
        columns: Column spec forwarded to Streamlit's column layout.
            Defaults to ``[1, 4]`` when omitted or falsy.
        **input_params: Keyword arguments forwarded unchanged to
            ``number_input`` (e.g. ``value``).

    Returns:
        Whatever the underlying ``number_input`` call returns.
    """
    # ``st.beta_columns`` was renamed to ``st.columns`` and the beta alias
    # was later removed; prefer the stable name and only fall back to the
    # beta alias on old Streamlit releases that lack it.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    label_col, input_col = make_columns(columns or [1, 4])

    # An empty "##" heading is emitted first as a spacer so the label text
    # sits lower in its column, roughly level with the input widget.
    label_col.markdown("##")
    label_col.markdown(label)

    # Give each widget a default key derived from its label to avoid
    # duplicate-key errors, since every input below shares the label "".
    input_params.setdefault("key", label)

    # The visible label lives in the first column, so the widget itself
    # gets an empty one.
    return input_col.number_input("", **input_params)
15
+
16
def key_value(key, value, columns=None):
    """Render a ``key``/``value`` pair side by side as two markdown columns.

    Args:
        key: Markdown text for the first column.
        value: Markdown text for the second column.
        columns: Column spec forwarded to Streamlit's column layout.
            Defaults to ``[2, 3]`` when omitted or falsy.

    Returns:
        None. The function only writes to the page.
    """
    # ``st.beta_columns`` was renamed to ``st.columns`` and the beta alias
    # was later removed; prefer the stable name and only fall back to the
    # beta alias on old Streamlit releases that lack it.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    key_col, value_col = make_columns(columns or [2, 3])

    # Each column gets an empty "##" heading first as a spacer, so both
    # texts are pushed down by the same amount and stay aligned.
    key_col.markdown("##")
    key_col.markdown(key)
    value_col.markdown("##")
    value_col.markdown(value)
24
+
25
  st.header("Transformer parameters")
26
+ bs = number_field('Batch size: ', value=10)
27
+ h = number_field('Num heads: ', value=16)
28
+ d = number_field('Dimension: ', value=768)
29
+ n = number_field('Seq length: ', value=1024)
30
 
31
  st.header('Query, Key, Value projection')
32
 
 
34
  mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d
35
 
36
  st.subheader("Multi-query Attention")
37
+ key_value("FLOP: ", str(mha_flop))
38
+ key_value("bytes: ", str(mha_bytes))
39
+ key_value("Arithm. intensity:", str(mha_flop/mha_bytes))
40
 
41
  mqa_flop = 2*bs*n*d*(1+2/h)*d
42
  mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
43
 
44
  st.subheader("Multi-query Attention")
45
+ key_value("FLOP: ", str(mqa_flop))
46
+ key_value("bytes: ", str(mqa_bytes))
47
+ key_value("Arithm. intensity:", str(mqa_flop/mqa_bytes))
48
 
49
  st.header('Attention')