harmdevries committed on
Commit
a275f69
·
1 Parent(s): ea57214

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -35,51 +35,54 @@ h = number_field('Num heads', value=16)
35
  d = number_field('Dimension', value=768)
36
  n_start = number_field('Start seq', value=1)
37
  n = number_field('End seq', value=1024)
38
- l = number_field('Num layers', value=2)
39
 
40
- st.header('Query, Key, Value projection')
41
 
42
- st.subheader("Multi-Head Attention")
 
43
  mha_flop = 2*bs*1*d*3*d
44
  mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
45
  c1, c2 = st.columns([2, 3])
46
  qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
47
 
48
- st.subheader("Multi-Query Attention")
49
  mqa_flop = 2*bs*1*d*(1+2/h)*d
50
  mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
51
  c1, c2 = st.columns([2, 3])
52
  qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
53
 
54
- st.header('Attention scores: query-key gemm')
55
- st.write("Calculation depends on sequence length (n). Take end of sequence.")
56
 
57
- st.subheader("Multi-Head Attention")
58
  mha_flop = 2*bs*h*(d/h)*n
59
  mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
60
  c1, c2 = st.columns([2, 3])
61
  att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
62
 
63
- st.subheader("Multi-Query Attention")
64
  mqa_flop = 2*bs*h*(d/h)*n
65
  mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
66
  c1, c2 = st.columns([2, 3])
67
  att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
68
 
69
- st.header('Attention scores: attention-value gemm')
70
  st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
71
- st.subheader("Multi-Head Attention")
72
  mha_flop = 2*bs*h*n*(d/h)
73
  mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
74
  c1, c2 = st.columns([2, 3])
75
  att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
76
 
77
- st.subheader("Multi-Query Attention")
78
  mqa_flop = 2*bs*h*n*(d/h)
79
  mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
80
  c1, c2 = st.columns([2, 3])
81
  att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
82
 
 
 
83
  st.header('MLP')
84
  st.subheader('First Linear')
85
  mlp1_flop = 2*bs*1*d*4*d
 
35
  d = number_field('Dimension', value=768)
36
  n_start = number_field('Start seq', value=1)
37
  n = number_field('End seq', value=1024)
38
+ l = number_field('Num layers', value=24)
39
 
40
+ st.header('Attention layer')
41
 
42
+ st.subheader('QKV projection')
43
+ st.caption("Multi-Head Attention")
44
  mha_flop = 2*bs*1*d*3*d
45
  mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
46
  c1, c2 = st.columns([2, 3])
47
  qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
48
 
49
+ st.caption("Multi-Query Attention")
50
  mqa_flop = 2*bs*1*d*(1+2/h)*d
51
  mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
52
  c1, c2 = st.columns([2, 3])
53
  qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
54
 
55
+ st.subheader('QK gemm')
56
+ st.write("Note that calculation depends on sequence length (n)")
57
 
58
+ st.caption("Multi-Head Attention")
59
  mha_flop = 2*bs*h*(d/h)*n
60
  mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
61
  c1, c2 = st.columns([2, 3])
62
  att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
63
 
64
+ st.caption("Multi-Query Attention")
65
  mqa_flop = 2*bs*h*(d/h)*n
66
  mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
67
  c1, c2 = st.columns([2, 3])
68
  att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
69
 
70
+ st.header('Attention-value gemm')
71
  st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
72
+ st.caption("Multi-Head Attention")
73
  mha_flop = 2*bs*h*n*(d/h)
74
  mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
75
  c1, c2 = st.columns([2, 3])
76
  att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
77
 
78
+ st.caption("Multi-Query Attention")
79
  mqa_flop = 2*bs*h*n*(d/h)
80
  mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
81
  c1, c2 = st.columns([2, 3])
82
  att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
83
 
84
+ st.subheader('Output projection')
85
+
86
  st.header('MLP')
87
  st.subheader('First Linear')
88
  mlp1_flop = 2*bs*1*d*4*d