harmdevries committed on
Commit
a21c8ab
·
1 Parent(s): 9ff92c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -4
app.py CHANGED
@@ -1,7 +1,74 @@
1
  import streamlit as st
2
 
 
 
 
 
 
 
 
3
  def calc_exec_time(comp_flop, mem_bytes):
4
  return (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
7
  arith_int = comp_flop/mem_bytes
@@ -14,15 +81,12 @@ def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
14
  c2.write(str(comp_flop))
15
  c1.write("MB: ")
16
  c2.write(str(mem_bytes))
17
- c1.write("Arithm. intensity:")
18
- c2.write(str(arith_int))
19
  c1.write("Time (ms):")
20
  c2.write(str(exec_time))
21
 
22
  return exec_time
23
 
24
- TFLOPS = 312e12
25
- GB_S = 1935e9
26
 
27
  st.sidebar.header("Transformer parameters")
28
  col1, col2 = st.sidebar.columns([2, 4])
@@ -31,11 +95,29 @@ bs = st.sidebar.number_input('Batch size', value=10)
31
  h = st.sidebar.number_input('Num heads',value=16)
32
  d = st.sidebar.number_input('Dimension', value=768)
33
  l = st.sidebar.number_input('Num layers', value=24)
 
34
  n_start = st.sidebar.number_input('Start seq', value=1)
35
  n = st.sidebar.number_input('End seq', value=1024)
36
 
37
  st.sidebar.header("GPU parameters")
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  st.header('Attention layer')
40
 
41
  st.subheader('QKV projection')
 
1
  import streamlit as st
2
 
3
# NVIDIA A100 peak numbers: tensor-core throughput (FLOP/s) and
# HBM memory bandwidth (bytes/s).
TFLOPS = 312e12
GB_S = 1935e9

# Per-kernel launch overhead, in ms.
THREAD_OVERHEAD = 0.005

def calc_exec_time(comp_flop, mem_bytes):
    """Roofline-style latency estimate for one kernel, in milliseconds.

    Sums the compute time (FLOPs at peak throughput) and the memory
    time (bytes at peak bandwidth), then converts seconds -> ms.
    """
    compute_s = comp_flop / TFLOPS
    memory_s = mem_bytes / GB_S
    return (compute_s + memory_s) * 1000
12
+
13
def qkv_mha_exec(bs, h, n, d):
    """Cost of the fused QKV projection under multi-head attention.

    Returns a (flop, nbytes, exec_time_ms) triple for one decoded token.
    """
    # GEMM of a [bs, 1, d] activation against a [d, 3d] weight.
    n_flop = 2 * bs * 1 * d * 3 * d
    # Input + weights + output traffic; factor 2 presumably bytes/elem (fp16).
    n_bytes = 2 * bs * 1 * d + 2 * 3 * d * d + 2 * bs * 1 * 3 * d
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
18
+
19
def qkv_mqa_exec(bs, h, n, d):
    """Cost of the QKV projection under multi-query attention.

    K/V are shared across heads, so their projections shrink by a
    factor of h relative to MHA. Returns (flop, nbytes, exec_time_ms).
    """
    # Q is full width d; K and V together contribute (2/h)*d columns.
    n_flop = 2 * bs * 1 * d * (1 + 2/h) * d
    n_bytes = 2 * bs * 1 * d + 2 * (2/h) * d * d + 2 * bs * 1 * (2/h) * d
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
24
+
25
def att1_mha_exec(bs, h, n, d):
    """Cost of the Q·K^T score computation for MHA at context length n.

    Returns (flop, nbytes, exec_time_ms).
    """
    # One query row (head dim d/h) against n cached keys, per head.
    n_flop = 2 * bs * h * (d/h) * n
    # Query read + per-head K-cache read + score write.
    n_bytes = 2 * bs * h * (d/h) + 2 * bs * h * n * (d/h) + 2 * bs * h * n
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
30
+
31
def att1_mqa_exec(bs, h, n, d):
    """Cost of the Q·K^T score computation for MQA at context length n.

    Same FLOPs as MHA, but the K cache is shared across heads, so its
    read traffic drops the per-head factor. Returns (flop, nbytes, exec_time_ms).
    """
    n_flop = 2 * bs * h * (d/h) * n
    n_bytes = 2 * bs * h * (d/h) + 2 * bs * n * (d/h) + 2 * bs * h * n
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
36
+
37
def att2_mha_exec(bs, h, n, d):
    """Cost of the scores·V aggregation for MHA at context length n.

    Returns (flop, nbytes, exec_time_ms).
    """
    # n attention weights times n cached values of head dim d/h, per head.
    n_flop = 2 * bs * h * n * (d/h)
    # Score read + per-head V-cache read + output write.
    n_bytes = 2 * bs * h * n + 2 * bs * h * n * (d/h) + 2 * bs * h * (d/h)
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
42
+
43
def att2_mqa_exec(bs, h, n, d):
    """Cost of the scores·V aggregation for MQA at context length n.

    The V cache is shared across heads, removing the per-head factor
    from its read traffic. Returns (flop, nbytes, exec_time_ms).
    """
    n_flop = 2 * bs * h * n * (d/h)
    # NOTE(review): the first two byte terms are identical in the original;
    # presumably scores + shared V-cache reads — confirm the intent.
    n_bytes = 2 * bs * n * (d/h) + 2 * bs * n * (d/h) + 2 * bs * h * (d/h)
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
48
+
49
def out_exec(bs, h, n, d):
    """Cost of the attention output projection ([bs,1,d] x [d,d] GEMM).

    Returns (flop, nbytes, exec_time_ms).
    """
    n_flop = 2 * bs * 1 * d * d
    n_bytes = 2 * bs * 1 * d + 2 * d * d + 2 * bs * 1 * d
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
54
+
55
def softmax_exec(bs, h, n, d):
    """Cost of the attention softmax over n scores per head.

    Modeled as purely memory-bound: FLOPs are deliberately not counted,
    only the read and write of the score tensor.
    Returns (flop, nbytes, exec_time_ms).
    """
    n_flop = 0
    n_bytes = 2 * bs * h * n + 2 * bs * h * n
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
60
+
61
def ln_exec(bs, h, n, d):
    """Cost of a layer norm over a [bs, 1, d] activation.

    Modeled as purely memory-bound (read + write of the activation);
    FLOPs are deliberately not counted. Returns (flop, nbytes, exec_time_ms).
    """
    n_flop = 0
    n_bytes = 2 * bs * 1 * d + 2 * bs * 1 * d
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
66
+
67
def mlp_exec(bs, h, n, d):
    """Cost of one MLP projection ([bs,1,d] x [d,4d] GEMM).

    Returns (flop, nbytes, exec_time_ms). Callers account for the second
    (4d -> d) projection by scaling this result.
    """
    n_flop = 2 * bs * 1 * d * 4 * d
    n_bytes = 2 * bs * 1 * d + 2 * d * 4 * d + 2 * bs * 1 * 4 * d
    return n_flop, n_bytes, calc_exec_time(n_flop, n_bytes)
72
 
73
  def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
74
  arith_int = comp_flop/mem_bytes
 
81
  c2.write(str(comp_flop))
82
  c1.write("MB: ")
83
  c2.write(str(mem_bytes))
 
 
84
  c1.write("Time (ms):")
85
  c2.write(str(exec_time))
86
 
87
  return exec_time
88
 
89
+
 
90
 
91
  st.sidebar.header("Transformer parameters")
92
  col1, col2 = st.sidebar.columns([2, 4])
 
95
  h = st.sidebar.number_input('Num heads',value=16)
96
  d = st.sidebar.number_input('Dimension', value=768)
97
  l = st.sidebar.number_input('Num layers', value=24)
98
+
99
  n_start = st.sidebar.number_input('Start seq', value=1)
100
  n = st.sidebar.number_input('End seq', value=1024)
101
 
102
  st.sidebar.header("GPU parameters")
103
 
104
+
105
st.header("Total execution time")

# Accumulate the per-token latency over the whole decode, for MHA vs. MQA.
# Each *_exec helper returns (flop, nbytes, exec_time_ms); index [2] picks
# the time. NOTE(review): range(n_start, n) excludes the final step n
# ("End seq") — confirm whether that is intended.
mqa_total_time = 0.
mha_total_time = 0.

for i in range(n_start, n):
    # Kernels common to both variants: output projection, softmax, the two
    # layer norms, and both MLP projections (hence the 2x factors).
    # BUG FIX: the original expression had a line continuation with no '+'
    # (a syntax error) plus a stray '3*ln_exec(...)' term missing '[2]'
    # (which would add a tuple to a float); the duplicated layer-norm term
    # has been dropped.
    shared_time = out_exec(bs, h, i, d)[2] + softmax_exec(bs, h, i, d)[2] \
        + 2*ln_exec(bs, h, i, d)[2] + 2*mlp_exec(bs, h, i, d)[2]
    mha_time = shared_time + qkv_mha_exec(bs, h, i, d)[2] + att1_mha_exec(bs, h, i, d)[2] + att2_mha_exec(bs, h, i, d)[2]
    mha_total_time += l*mha_time
    mqa_time = shared_time + qkv_mqa_exec(bs, h, i, d)[2] + att1_mqa_exec(bs, h, i, d)[2] + att2_mqa_exec(bs, h, i, d)[2]
    mqa_total_time += l*mqa_time

st.write("MHA exec time: " + str(mha_total_time))
st.write("MQA exec time: " + str(mqa_total_time))
120
+
121
  st.header('Attention layer')
122
 
123
  st.subheader('QKV projection')