harmdevries commited on
Commit
999ad78
·
1 Parent(s): 2a3864d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import streamlit as st
2
 
3
- def number_field(label, **kwargs):
4
- c1, c2 = st.columns([2, 4])
5
  c1.write(label)
6
-
7
  return c2.number_input('', **kwargs)
8
 
9
  def calc_exec_time(comp_flop, mem_bytes):
@@ -30,15 +28,15 @@ def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
30
  TFLOPS = 312e12
31
  GB_S = 1935e9
32
 
33
- st.header("Transformer parameters")
34
- col1, col2 = st.columns([2, 4])
35
 
36
- bs = number_field('Batch size', value=10)
37
- h = number_field('Num heads', value=16)
38
- d = number_field('Dimension', value=768)
39
- n_start = number_field('Start seq', value=1)
40
- n = number_field('End seq', value=1024)
41
- l = number_field('Num layers', value=24)
42
 
43
  st.header('Attention layer')
44
 
 
1
  import streamlit as st
2
 
3
+ def number_field(label, c1, c2, **kwargs):
 
4
  c1.write(label)
 
5
  return c2.number_input('', **kwargs)
6
 
7
  def calc_exec_time(comp_flop, mem_bytes):
 
28
  TFLOPS = 312e12
29
  GB_S = 1935e9
30
 
31
+ st.sidebar.header("Transformer parameters")
32
+ col1, col2 = st.sidebar.columns([2, 4])
33
 
34
+ bs = number_field('Batch size', col1, col2, value=10)
35
+ h = number_field('Num heads', col1, col2, value=16)
36
+ d = number_field('Dimension', col1, col2, value=768)
37
+ n_start = number_field('Start seq', col1, col2, value=1)
38
+ n = number_field('End seq', col1, col2, value=1024)
39
+ l = number_field('Num layers', col1, col2, value=24)
40
 
41
  st.header('Attention layer')
42