harmdevries commited on
Commit
9ff92c0
·
1 Parent(s): 999ad78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -1,9 +1,5 @@
1
  import streamlit as st
2
 
3
- def number_field(label, c1, c2, **kwargs):
4
- c1.write(label)
5
- return c2.number_input('', **kwargs)
6
-
7
  def calc_exec_time(comp_flop, mem_bytes):
8
  return (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
9
 
@@ -31,12 +27,14 @@ GB_S = 1935e9
31
  st.sidebar.header("Transformer parameters")
32
  col1, col2 = st.sidebar.columns([2, 4])
33
 
34
- bs = number_field('Batch size', col1, col2, value=10)
35
- h = number_field('Num heads', col1, col2, value=16)
36
- d = number_field('Dimension', col1, col2, value=768)
37
- n_start = number_field('Start seq', col1, col2, value=1)
38
- n = number_field('End seq', col1, col2, value=1024)
39
- l = number_field('Num layers', col1, col2, value=24)
 
 
40
 
41
  st.header('Attention layer')
42
 
 
1
  import streamlit as st
2
 
 
 
 
 
3
  def calc_exec_time(comp_flop, mem_bytes):
4
  return (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
5
 
 
27
  st.sidebar.header("Transformer parameters")
28
  col1, col2 = st.sidebar.columns([2, 4])
29
 
30
+ bs = st.sidebar.number_input('Batch size', value=10)
31
+ h = st.sidebar.number_input('Num heads',value=16)
32
+ d = st.sidebar.number_input('Dimension', value=768)
33
+ l = st.sidebar.number_input('Num layers', value=24)
34
+ n_start = st.sidebar.number_input('Start seq', value=1)
35
+ n = st.sidebar.number_input('End seq', value=1024)
36
+
37
+ st.sidebar.header("GPU parameters")
38
 
39
  st.header('Attention layer')
40