harmdevries committed on
Commit
c904168
·
1 Parent(s): b9ffa51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -60
app.py CHANGED
@@ -6,6 +6,24 @@ def number_field(label, **args):
6
 
7
  return c2.number_input('', **args)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  TFLOPS = 312e12
10
  GB_S = 1935e9
11
 
@@ -20,82 +38,38 @@ n = number_field('Seq length', value=1024)
20
  st.header('Query, Key, Value projection')
21
 
22
 
23
- mha_flop = 2*bs*1*d*3*d
24
- mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
25
- mha_int = mha_flop/mha_bytes
26
- mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
27
-
28
- mha_flop = round(mha_flop/1e9, 2)
29
- mha_bytes = round(mha_bytes/1e6, 2)
30
-
31
 
32
  st.subheader("Multi-Head Attention")
 
 
33
  c1, c2 = st.columns([2, 3])
34
- c1.write("GFLOP:")
35
- c2.write(str(mha_flop))
36
- c1.write("MB: ")
37
- c2.write(str(mha_bytes))
38
- c1.write("Arithm. intensity:")
39
- c2.write(str(mha_int))
40
- c1.write("Time (ms):")
41
- c2.write(str(mha_time))
42
-
43
 
 
44
  mqa_flop = 2*bs*1*d*(1+2/h)*d
45
  mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
46
- mqa_intensity = mqa_flop/mqa_bytes
47
- mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000
48
-
49
- mqa_flop = round(mqa_flop/1e9, 2)
50
- mqa_bytes = round(mqa_bytes/1e6, 2)
51
-
52
- st.subheader("Multi-Query Attention")
53
  c1, c2 = st.columns([2, 3])
54
- c1.write("GFLOP:")
55
- c2.write(str(mqa_flop))
56
- c1.write("MB:")
57
- c2.write(str(mqa_bytes))
58
- c1.write("Arithm. intensity:")
59
- c2.write(str(mqa_intensity))
60
- c1.write("Time (ms):")
61
- c2.write(str(mqa_time))
62
 
63
  st.header('Attention')
 
 
64
  mha_flop = 2*bs*h*(d/h)*n
65
  mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
66
- mha_int = mha_flop/mha_bytes
67
- mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
68
 
69
- mha_flop = round(mha_flop/1e9, 2)
70
- mha_bytes = round(mha_bytes/1e6, 2)
71
 
72
 
73
- st.subheader("Multi-Head Attention")
74
- c1, c2 = st.columns([2, 3])
75
- c1.write("GFLOP:")
76
- c2.write(str(mha_flop))
77
- c1.write("MB: ")
78
- c2.write(str(mha_bytes))
79
- c1.write("Arithm. intensity:")
80
- c2.write(str(mha_int))
81
- c1.write("Time (ms):")
82
- c2.write(str(mha_time))
83
 
 
84
  mqa_flop = 2*bs*h*(d/h)*n
85
  mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
86
- mqa_intensity = mqa_flop/mqa_bytes
87
- mqa_time = (mqa_flop/TFLOPS + mqa_bytes/GB_S)*1000
88
 
89
- mqa_flop = round(mqa_flop/1e9, 2)
90
- mqa_bytes = round(mqa_bytes/1e6, 2)
91
 
92
- st.subheader("Multi-Query Attention")
93
- c1, c2 = st.columns([2, 3])
94
- c1.write("GFLOP:")
95
- c2.write(str(mqa_flop))
96
- c1.write("MB:")
97
- c2.write(str(mqa_bytes))
98
- c1.write("Arithm. intensity:")
99
- c2.write(str(mqa_intensity))
100
- c1.write("Time (ms):")
101
- c2.write(str(mqa_time))
 
6
 
7
  return c2.number_input('', **args)
8
 
9
def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
    """Write one kernel's cost summary into two columns and return its time.

    Labels go to ``c1`` and the matching values to ``c2``; both only need a
    ``write`` method (callers in this file pass the pair from ``st.columns``).

    Args:
        c1: Column-like object that receives the row labels.
        c2: Column-like object that receives the row values.
        comp_flop: Floating-point operations performed by the kernel.
        mem_bytes: Bytes moved by the kernel.

    Returns:
        The estimated execution time in milliseconds, computed as the sum of
        the compute time (``comp_flop / TFLOPS``) and the memory time
        (``mem_bytes / GB_S``), using the module-level constants.
    """
    # Derive both metrics from the raw inputs before they are rescaled below.
    arith_int = comp_flop / mem_bytes
    exec_time = (comp_flop / TFLOPS + mem_bytes / GB_S) * 1000

    # Rescale for display only: FLOP -> GFLOP, bytes -> MB.
    gflop = round(comp_flop / 1e9, 2)
    mbytes = round(mem_bytes / 1e6, 2)

    c1.write("GFLOP:")
    c2.write(str(gflop))
    c1.write("MB: ")
    c2.write(str(mbytes))
    c1.write("Arithm. intensity:")
    c2.write(str(arith_int))
    c1.write("Time (ms):")
    c2.write(str(exec_time))

    return exec_time
26
+
27
  TFLOPS = 312e12
28
  GB_S = 1935e9
29
 
 
38
  st.header('Query, Key, Value projection')
39
 
40
 
 
 
 
 
 
 
 
 
41
 
42
  st.subheader("Multi-Head Attention")
43
+ mha_flop = 2*bs*1*d*3*d
44
+ mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
45
  c1, c2 = st.columns([2, 3])
46
+ qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
 
 
 
 
 
 
 
 
47
 
48
+ st.subheader("Multi-Query Attention")
49
  mqa_flop = 2*bs*1*d*(1+2/h)*d
50
  mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
 
 
 
 
 
 
 
51
  c1, c2 = st.columns([2, 3])
52
# Keep the Multi-Query result under its own name so it does not overwrite the
# Multi-Head QKV projection time computed earlier.
qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
 
 
 
 
 
 
 
53
 
54
  st.header('Attention')
55
+
56
+ st.subheader("Multi-Head Attention")
57
  mha_flop = 2*bs*h*(d/h)*n
58
  mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
59
+ c1, c2 = st.columns([2, 3])
60
+ att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
61
 
 
 
62
 
63
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ st.subheader("Multi-Query Attention")
66
  mqa_flop = 2*bs*h*(d/h)*n
67
  mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
68
+ c1, c2 = st.columns([2, 3])
69
+ att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
70
 
 
 
71
 
72
+ st.header('MLP')
73
+ st.subheader('First Linear')
74
+
75
+ st.subheader('Second Linear')