harmdevries committed on
Commit
8cbefab
·
1 Parent(s): 54cd0e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -7
app.py CHANGED
@@ -33,12 +33,12 @@ col1, col2 = st.columns([2, 4])
33
  bs = number_field('Batch size', value=10)
34
  h = number_field('Num heads', value=16)
35
  d = number_field('Dimension', value=768)
36
- n = number_field('Seq length', value=1024)
 
 
37
 
38
  st.header('Query, Key, Value projection')
39
 
40
-
41
-
42
  st.subheader("Multi-Head Attention")
43
  mha_flop = 2*bs*1*d*3*d
44
  mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
@@ -51,16 +51,28 @@ mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
51
  c1, c2 = st.columns([2, 3])
52
  qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
53
 
54
- st.header('Attention')
 
55
 
56
  st.subheader("Multi-Head Attention")
57
  mha_flop = 2*bs*h*(d/h)*n
58
  mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
59
  c1, c2 = st.columns([2, 3])
60
- att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
61
-
62
 
 
 
 
 
 
63
 
 
 
 
 
 
 
 
64
 
65
  st.subheader("Multi-Query Attention")
66
  mqa_flop = 2*bs*h*(d/h)*n
@@ -68,7 +80,6 @@ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
68
  c1, c2 = st.columns([2, 3])
69
  att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
70
 
71
-
72
  st.header('MLP')
73
  st.subheader('First Linear')
74
  mlp1_flop = 2*bs*1*d*4*d
 
33
  bs = number_field('Batch size', value=10)
34
  h = number_field('Num heads', value=16)
35
  d = number_field('Dimension', value=768)
36
+ n_start = number_field('Start seq', value=1)
37
+ n = number_field('End seq', value=1024)
38
+ l = number_field('Num layers', value=2)
39
 
40
  st.header('Query, Key, Value projection')
41
 
 
 
42
  st.subheader("Multi-Head Attention")
43
  mha_flop = 2*bs*1*d*3*d
44
  mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
 
51
  c1, c2 = st.columns([2, 3])
52
  qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
53
 
54
+ st.header('Attention scores: query-key gemm')
55
+ st.write("Calculation depends on sequence length (n). Take end of sequence.")
56
 
57
  st.subheader("Multi-Head Attention")
58
  mha_flop = 2*bs*h*(d/h)*n
59
  mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
60
  c1, c2 = st.columns([2, 3])
61
+ att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
 
62
 
63
+ st.subheader("Multi-Query Attention")
64
+ mqa_flop = 2*bs*h*(d/h)*n
65
+ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
66
+ c1, c2 = st.columns([2, 3])
67
+ att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
68
 
69
+ st.header('Attention scores: ')
70
+ st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
71
+ st.subheader("Multi-Head Attention")
72
+ mha_flop = 2*bs*h*(d/h)*n
73
+ mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
74
+ c1, c2 = st.columns([2, 3])
75
+ att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
76
 
77
  st.subheader("Multi-Query Attention")
78
  mqa_flop = 2*bs*h*(d/h)*n
 
80
  c1, c2 = st.columns([2, 3])
81
  att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
82
 
 
83
  st.header('MLP')
84
  st.subheader('First Linear')
85
  mlp1_flop = 2*bs*1*d*4*d