def activation_memory(
    a, # attention heads
    b, # micro batch size
    h, # hidden dimension size
    h_ff, # feedforward dimension size (often h_ff = 4h)
    L, # number of layers
    s, # sequence length
    mixed=True,
    recomputation="none"
    ):
    
    # https://arxiv.org/pdf/2205.05198
    if mixed:
        bytes_per_value = 2 
    else:
        bytes_per_value = 4

    one_layer_attention = s * b * h * (bytes_per_value * 5 + 1) + ((2 * bytes_per_value + 1) * a * s * s * b)  # attention block: 11*s*b*h + 5*a*s^2*b in fp16 (sec 4.1)
    one_layer_feedforward_mlp = (
        s * b * h * bytes_per_value       # input of the 1st linear layer
        + s * b * h_ff * bytes_per_value  # input of the 2nd linear layer
        + s * b * h_ff * bytes_per_value  # input of the activation function (not strictly needed for ReLU)
        + s * b * h                       # dropout mask (1 byte per element)
    )
    one_layer_feedforward_swiglu = (
        s * b * h * bytes_per_value           # inputs of the input/output linear layers
        + s * b * h_ff * bytes_per_value
        + s * b * h_ff * bytes_per_value * 3  # inputs of the activation function
        + s * b * h                           # dropout mask (1 byte per element, lower precision than activations)
    )


    if recomputation == "none":
        one_layer = one_layer_attention # eq (2)
    elif recomputation =="selective":
        one_layer = s * b * h * 34 # eq (6)
    elif recomputation =="full":
        one_layer = s * b * h * 2
    else:
        raise ValueError()
    
    input_dropout = 0  # s * b * h # section 4.3

    total = L * one_layer + input_dropout
        
    return total
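

# Added sketch (not in the original file): a closed-form cross-check for the
# recomputation="none" path. With mixed precision and h_ff = 4*h, the total
# returned above should equal eq (2) of https://arxiv.org/pdf/2205.05198,
# L * s * b * h * (34 + 5 * a * s / h).
def activation_memory_closed_form(a, b, h, L, s):
    # activation bytes for L layers, fp16/bf16 activations, h_ff = 4h, no recomputation
    return L * s * b * h * (34 + 5 * a * s / h)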


def param_grads_opt(
    h, # hidden dimension size
    L, # number of layers
    s, # sequence length
    v, # vocab size
    k=8, # optimizer state bytes per parameter (Adam: 4 bytes momentum + 4 bytes variance)
    mixed=True # mixed precision training
    ):
    
    # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
    # note: this is without GQA or MQA
    
    emb = h * (v + s)               # token embeddings (v*h) + position embeddings (s*h)
    one_layer = 12 * h**2 + 13 * h  # attention + MLP weights/biases + 2 LayerNorms per layer
    other = 2 * h                   # final LayerNorm

    n = emb + L * one_layer + other
    
    # 3.1 https://arxiv.org/pdf/1910.02054
    
    if mixed:
        k += 4 # additional full-precision (fp32) copy of the weights
        bytes_per_parameter = 2
    else:
        bytes_per_parameter = 4

    return bytes_per_parameter * n, bytes_per_parameter * n, k * n
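

if __name__ == "__main__":
    # Illustrative usage (added, not in the original file). The config below is
    # an assumption, roughly a 7B-class dense transformer, chosen only to show
    # the functions in action.
    a, b, h, h_ff, L, s, v = 32, 1, 4096, 4 * 4096, 32, 2048, 32000

    act = activation_memory(a, b, h, h_ff, L, s, mixed=True, recomputation="none")
    act_closed = activation_memory_closed_form(a, b, h, L, s)
    params_bytes, grads_bytes, opt_bytes = param_grads_opt(h, L, s, v)

    gib = 1024 ** 3
    print(f"activations (no recomputation):   {act / gib:.2f} GiB per micro batch")
    print(f"activations (eq (2) closed form): {act_closed / gib:.2f} GiB per micro batch")
    print(f"params: {params_bytes / gib:.2f} GiB, grads: {grads_bytes / gib:.2f} GiB, "
          f"optimizer states: {opt_bytes / gib:.2f} GiB")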