harmdevries commited on
Commit
86cd028
·
1 Parent(s): f0d14cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -6,9 +6,12 @@ def number_field(label, **kwargs):
6
 
7
  return c2.number_input('', **kwargs)
8
 
 
 
 
9
  def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
10
  arith_int = comp_flop/mem_bytes
11
- exec_time = (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
12
 
13
  comp_flop = round(comp_flop/1e9, 2)
14
  mem_bytes = round(mem_bytes/1e6, 2)
@@ -82,13 +85,21 @@ c1, c2 = st.columns([2, 3])
82
  att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
83
 
84
  st.subheader('Output projection')
85
- mlp1_flop = 2*bs*1*d
86
- mlp1_bytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
87
  c1, c2 = st.columns([2, 3])
88
- mlp1_time = print_kernel_execution(c1, c2, mlp1_flop, mlp1_bytes)
89
 
90
  st.subheader('Element-wise ops')
91
- st.write("A couple of layers ")
 
 
 
 
 
 
 
 
92
 
93
  st.header('MLP')
94
  st.subheader('First Linear')
 
6
 
7
  return c2.number_input('', **kwargs)
8
 
9
+ def calc_exec_time(comp_flop, mem_bytes):
10
+ return (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
11
+
12
  def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
13
  arith_int = comp_flop/mem_bytes
14
+ exec_time = calc_exec_time(comp_flop, mem_bytes)
15
 
16
  comp_flop = round(comp_flop/1e9, 2)
17
  mem_bytes = round(mem_bytes/1e6, 2)
 
85
  att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
86
 
87
  st.subheader('Output projection')
88
+ out_flop = 2*bs*1*d*d
89
+ out_bytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
90
  c1, c2 = st.columns([2, 3])
91
+ out_time = print_kernel_execution(c1, c2, out_flop, out_bytes)
92
 
93
  st.subheader('Element-wise ops')
94
+ st.write("We also need to take into the softmax layer and layer norm")
95
+
96
+ st.caption("Softmax")
97
+ softmax_bytes = 2*bs*h*n + 2*bs*h*n
98
+ c1, c2 = st.columns([2, 3])
99
+ softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes))
100
+
101
+ st.caption("Layer norm")
102
+
103
 
104
  st.header('MLP')
105
  st.subheader('First Linear')