import numpy as np


# Multiply a list of factors with the accumulation forced to int64, so the
# product of large shape factors does not overflow a 32-bit default integer.
def multiplication_in_int64(array):
    return np.cumprod(np.array(array, dtype=np.int64))[-1]

# FLOPs of the matrix multiplication A @ B, where the last axis of shapeA is
# the contraction dimension: 2 * prod(shapeA[:-1]) * shapeA[-1] * shapeB[-1].
def matrix_operation(shapeA, shapeB):
    assert shapeA[-1] == shapeB[0]
    op = np.cumprod(np.array(shapeA[:-1], np.float64))
    return multiplication_in_int64([2, op[-1], shapeA[-1], shapeB[-1]])
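# Quick worked example of the formula above (illustrative shapes only):
# matrix_operation([4, 128, 1024], [1024, 4096])
#     = 2 * (4 * 128) * 1024 * 4096 = 4,294,967,296 FLOPs.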

def word_embedding_operation(model_config, inference_config):
    #Given:
    #\begin{itemize}
    #    \item Matrix \( X \) of size \( B \times s \) (batch size and sequence length respectively), with each token treated as a one-hot row over the vocabulary.
    #    \item Embedding matrix \( W_e \) of size \( n_{vocab} \times d_{model} \).
    #\end{itemize}
    
    #The resultant matrix after the multiplication will be of size \( B \times s \times d_{model} \).
    #Each element of this resultant matrix costs \( 2 \times n_{vocab} \) FLOPs, since a single output element of a matrix product takes \( 2N \) FLOPs (one multiply and one add per entry of the common dimension \( N \)). The total FLOP count of the multiplication is therefore:
    #\begin{equation}
    #2 \times B \times s \times n_{vocab} \times d_{model}
    #\end{equation}
    A = [inference_config['batchsize'], inference_config['input_seq_length'], model_config['vocab_size']]
    B = [model_config['vocab_size'], model_config['hidden_size']]
    return matrix_operation(A, B)
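# Worked example for word_embedding_operation (illustrative values only):
# with batchsize=1, input_seq_length=128, vocab_size=32000 and hidden_size=4096,
# this returns 2 * 1 * 128 * 32000 * 4096 = 33,554,432,000 (about 3.4e10) FLOPs.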


def positional_embedding_operation(model_config, inference_config):
    # Adding positional embeddings is elementwise: one FLOP per element of the B x s x d_model activation.
    return multiplication_in_int64([inference_config['batchsize'], inference_config['input_seq_length'], model_config['hidden_size']])

### The Q, K, and V projections below have identical shapes, so their FLOP counts are the same.
def attention_K_operation(model_config, inference_config, seq_length):
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size'] // model_config['num_attention_heads']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

def attention_Q_operation(model_config, inference_config, seq_length):
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size'] // model_config['num_attention_heads']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

def attention_V_operation(model_config, inference_config, seq_length):
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size'] // model_config['num_attention_heads']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)
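# Sanity check of the per-head accounting above (algebra only): each head
# projects hidden_size -> hidden_size / num_heads, so summing over heads gives
#     num_heads * 2 * B * s * d * (d / num_heads) = 2 * B * s * d * d,
# i.e. exactly the FLOPs of one full d x d projection per layer, as expected.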

## Attention scores (QK^T), softmax, and weighted value aggregation, per head and per layer.
def attention_QK_operation(model_config, inference_config, seq_length_Q, seq_length_K):
    A = [inference_config['batchsize'], seq_length_Q, model_config['hidden_size'] // model_config['num_attention_heads']]
    B = [model_config['hidden_size'] // model_config['num_attention_heads'], seq_length_K]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

def attention_softmax_operation(model_config, inference_config, seq_length_Q, seq_length_K=None):
    # Ref: Ouyang, A. (2023). Understanding the Performance of Transformer Inference (Doctoral dissertation, Massachusetts Institute of Technology).
    # The factor of 3 is the modeled per-score cost of the softmax
    # (roughly one exponentiation, one addition to the normalizer, and one division).
    # If only one length is given, the score matrix is assumed to be square (seq_length_K = seq_length_Q).
    if seq_length_K is None:
        seq_length_K = seq_length_Q
    softmax_operation = (3*inference_config['batchsize']*seq_length_Q*seq_length_K)
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * softmax_operation

def attention_multV_operation(model_config, inference_config, seq_length_Q, seq_length_V):
    A = [inference_config['batchsize'], seq_length_Q, seq_length_V]
    B = [seq_length_V, model_config['hidden_size'] // model_config['num_attention_heads']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

def attention_out_operation(model_config, inference_config, seq_length):
    # Output projection W_O: hidden_size x hidden_size, applied once per layer (not per head).
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size']]
    return model_config['num_hidden_layers'] * matrix_operation(A, B)

def layernorm_operation(model_config, inference_config, seq_length):
    # Ref: Ouyang, A. (2023). Understanding the Performance of Transformer Inference (Doctoral dissertation, Massachusetts Institute of Technology).
    # The factor of 5 is the modeled per-element cost of layernorm
    # (roughly: mean, variance, normalization, scale, and shift).
    # model_config['layernorm_operation'] is expected to hold the number of layernorm applications per transformer layer.
    layernorm_operation = (5*inference_config['batchsize']*seq_length*model_config['hidden_size'])
    return model_config['num_hidden_layers'] * model_config['layernorm_operation'] * layernorm_operation


def mlp1_operation(model_config, inference_config, seq_length):
    # MLP up-projection: hidden_size -> intermediate_size.
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['intermediate_size']]
    return model_config['num_hidden_layers'] * matrix_operation(A, B)

def mlp2_operation(model_config, inference_config, seq_length):
    # MLP down-projection: intermediate_size -> hidden_size.
    A = [inference_config['batchsize'], seq_length, model_config['intermediate_size']]
    B = [model_config['intermediate_size'], model_config['hidden_size']]
    return model_config['num_hidden_layers'] * matrix_operation(A, B)

def prefilling_operation(model_config, inference_config):
    prefilling_operation_count = {}
    prefilling_operation_count['word_embedding'] = word_embedding_operation(model_config, inference_config)
    prefilling_operation_count['positional_embedding'] = positional_embedding_operation(model_config, inference_config)
    
    prefilling_operation_count['attention_Q'] = attention_Q_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_K'] = attention_K_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_V'] = attention_V_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_QK'] = attention_QK_operation(model_config, inference_config, inference_config['input_seq_length'], inference_config['input_seq_length'])
    prefilling_operation_count['attention_softmax'] = attention_softmax_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_multV'] = attention_multV_operation(model_config, inference_config, inference_config['input_seq_length'], inference_config['input_seq_length'])
    prefilling_operation_count['attention_out'] = attention_out_operation(model_config, inference_config, inference_config['input_seq_length'])

    prefilling_operation_count['layernorm'] = layernorm_operation(model_config, inference_config, inference_config['input_seq_length'])

    prefilling_operation_count['mlp1'] = mlp1_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['mlp2'] = mlp2_operation(model_config, inference_config, inference_config['input_seq_length'])
    
    prefilling_operation_count['embeddings'] = prefilling_operation_count['word_embedding'] + prefilling_operation_count['positional_embedding']
    prefilling_operation_count['attention'] = sum([v for k,v in prefilling_operation_count.items() if 'attention' in k])
    prefilling_operation_count['mlp'] = prefilling_operation_count['mlp1'] + prefilling_operation_count['mlp2']
    prefilling_operation_count['total'] = (prefilling_operation_count['embeddings'] + prefilling_operation_count['attention'] + prefilling_operation_count['mlp'] + prefilling_operation_count['layernorm'])
    
    return prefilling_operation_count

def generation_operation(model_config, inference_config):
    generation_operation_count = {}
    generation_operation_count['word_embedding'] = 0
    generation_operation_count['positional_embedding'] = 0
    generation_operation_count['attention_K'] = 0
    generation_operation_count['attention_V'] = 0
    generation_operation_count['attention_Q'] = 0
    generation_operation_count['attention_QK'] = 0
    generation_operation_count['attention_softmax'] = 0
    generation_operation_count['attention_multV'] = 0
    generation_operation_count['attention_out'] = 0
    generation_operation_count['mlp1'] = 0
    generation_operation_count['mlp2'] = 0
    generation_operation_count['layernorm'] = 0

    for t in range(inference_config['output_seq_length']):
        if inference_config['KV_cache']:
            generation_operation_count['attention_K'] += attention_K_operation(model_config, inference_config, 1)
            generation_operation_count['attention_V'] += attention_V_operation(model_config, inference_config, 1)
            generation_operation_count['attention_Q'] += attention_Q_operation(model_config, inference_config, 1)
            generation_operation_count['attention_QK'] += attention_QK_operation(model_config, inference_config, seq_length_Q=1, seq_length_K=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_softmax'] += attention_softmax_operation(model_config, inference_config, seq_length_Q=1, seq_length_K=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_multV'] += attention_multV_operation(model_config, inference_config, seq_length_Q=1, seq_length_V=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_out'] += attention_out_operation(model_config, inference_config, 1)
            generation_operation_count['mlp1'] += mlp1_operation(model_config, inference_config, 1)
            generation_operation_count['mlp2'] += mlp2_operation(model_config, inference_config, 1)
        else:
            generation_operation_count['attention_K'] += attention_K_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_V'] += attention_V_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_Q'] += attention_Q_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_QK'] += attention_QK_operation(model_config, inference_config, seq_length_Q=(t+1)+inference_config['input_seq_length'], seq_length_K=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_softmax'] += attention_softmax_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_multV'] += attention_multV_operation(model_config, inference_config, seq_length_Q=(t+1)+inference_config['input_seq_length'], seq_length_V=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_out'] += attention_out_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['mlp1'] += mlp1_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['mlp2'] += mlp2_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])

        # With a KV cache, only the newly generated token passes through each layer, so layernorm sees a single position.
        layernorm_seq_length = 1 if inference_config['KV_cache'] else (t+1)+inference_config['input_seq_length']
        generation_operation_count['layernorm'] += layernorm_operation(model_config, inference_config, layernorm_seq_length)

    generation_operation_count['embeddings'] = generation_operation_count['word_embedding'] + generation_operation_count['positional_embedding'] 
    generation_operation_count['attention'] = sum([v for k,v in generation_operation_count.items() if 'attention' in k])
    generation_operation_count['mlp'] = generation_operation_count['mlp1'] + generation_operation_count['mlp2']
    generation_operation_count['total'] = (generation_operation_count['attention'] + generation_operation_count['mlp'] + generation_operation_count['layernorm'])

    return generation_operation_count
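

if __name__ == '__main__':
    # Minimal usage sketch. The values below are illustrative (roughly GPT-2
    # small sized) and are not read from any real config file; only the
    # dictionary keys are the ones this module actually expects.
    model_config = {
        'vocab_size': 50257,
        'hidden_size': 768,
        'intermediate_size': 3072,
        'num_attention_heads': 12,
        'num_hidden_layers': 12,
        'layernorm_operation': 2,  # assumed: two layernorms per transformer layer
    }
    inference_config = {
        'batchsize': 1,
        'input_seq_length': 128,
        'output_seq_length': 32,
        'KV_cache': True,
    }

    prefill_ops = prefilling_operation(model_config, inference_config)
    generation_ops = generation_operation(model_config, inference_config)
    print(f"Prefill FLOPs:    {int(prefill_ops['total']):,}")
    print(f"Generation FLOPs: {int(generation_ops['total']):,}")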