import gradio as gr
import pandas as pd

# col=['Layer number', 'Hidden size', 'FFN Hidden size', 'Sequence length', 'Head number', 'Group number', 
#         'dp', 'tp', 'pp', 'cp', 'GPU numbers', 'Batch size', 'FP8', 'Model parameters', 'Model_states', 'Activation', 'Total']

col=['L', 'H', 'FFN', 'S', 'A', 'G', 
        'DP', 'TP', 'PP', 'CP', 'GPUs', 'B', 'FP8', 'Model parameters (B)', 'Model states (GB)', 'Activation (GB)', 'Total (GB)']

abbr = """
    <div align="center">

    > **Abbreviations of symbols:**
    |Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|
    |---|---|---|---|---|---|---|---|---|---|---|---|
    |L|Layer number|H|Hidden size|FFN|FFN Hidden size|S|Sequence length|A|Head number|G|Group number|

    </div>
    """

def Get_GigaByte(memory):
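    # bytes -> GB (dividing by 1024**3)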
    return memory / 1024**3

def Get_BillionParameter(parameter):
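    # raw parameter count -> billions (dividing by 1000**3)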
    return parameter / 1000**3

# model states:
def Compute_Parameters_input(seq_length, hidden_size, vocab_size, act_func, tp):
    num_parameters_word_embedding = hidden_size * vocab_size / tp
    # position embedding
    if act_func == "LLaMA":
        num_parameters_position_embedding = 0 
    else:
        num_parameters_position_embedding = seq_length * hidden_size / tp

    return num_parameters_word_embedding + num_parameters_position_embedding

def Compute_Parameters_output(hidden_size, vocab_size, is_tie_word_embedding, act_func, tp):
    # layernorm: h/2h
    if act_func == "LLaMA":
        num_parameters_output_layernorm = hidden_size # RMSNorm
    else:
        num_parameters_output_layernorm = 2 * hidden_size # LayerNorm

    if is_tie_word_embedding == "True":
        num_parameters_output_embedding = 0 # due to sharedWordEmbedding
    else:
        num_parameters_output_embedding = hidden_size * vocab_size / tp

    return num_parameters_output_layernorm + num_parameters_output_embedding

def Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, act_func, tp):
    # attention: 
    # layernorm: h/2h
    if act_func == "LLaMA":
        num_parameters_attention = hidden_size # RMSNorm
    else:
        num_parameters_attention = 2 * hidden_size # LayerNorm
    # Q weight: h*h/tp, KV weight: 2*kv_h*h/tp (kv_h = h for MHA, smaller for GQA), bias: (h + 2*kv_h)/tp
    # output linear weight: h*h/tp, bias: h
    num_parameters_attention_Q_weight = hidden_size * hidden_size / tp
    num_parameters_attention_KV_weight = 2 * kv_hidden_size * hidden_size / tp
    num_parameters_attention_Linear_weight = hidden_size * hidden_size / tp

    num_parameters_attention += num_parameters_attention_Q_weight + num_parameters_attention_KV_weight + num_parameters_attention_Linear_weight
    if is_bias == "True":
        num_parameters_attention += (hidden_size + 2 * kv_hidden_size) / tp + hidden_size
    
    return num_parameters_attention

def Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp):
    # MLP: 
    # layernorm: h/2h
    if act_func == "LLaMA":
        num_parameters_mlp = hidden_size # RMSNorm
    else:
        num_parameters_mlp = 2 * hidden_size # LayerNorm
    # GPT MLP: fc1 weight h*ffn/tp + fc2 weight ffn*h/tp, bias: ffn/tp + h
    # LLaMA MLP (SwiGLU): gate/up/down weights 3*h*ffn/tp, bias: 2*ffn/tp + h
    if act_func == "LLaMA":
        num_parameters_mlp += hidden_size * ffn_size * 3 / tp
        if is_bias == "True":
            num_parameters_mlp += ffn_size * 2 / tp + hidden_size
    else:
        num_parameters_mlp += hidden_size * ffn_size * 2 / tp
        if is_bias == "True":
            num_parameters_mlp += ffn_size / tp + hidden_size
    
    return num_parameters_mlp

def Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, tp, pp):
    if is_group_query == "False":
        group_query_num = head_num
    kv_hidden_size = hidden_size / head_num * group_query_num
    
    # input part
    num_parameters_input = Compute_Parameters_input(seq_length, hidden_size, vocab_size, act_func, tp)

    # middle layers part
    num_parameters_attention = Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, act_func, tp)
    num_parameters_mlp = Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp)
    num_parameters_in_single_layer = num_parameters_attention + num_parameters_mlp
    num_parameters_in_total_layers = num_parameters_in_single_layer * layer_num / pp
    
    # output part
    parameters_output = Compute_Parameters_output(hidden_size, vocab_size, is_tie_word_embedding, act_func, tp)    

    if pp == 1:
        num_parameters_total = (
            num_parameters_input
            + num_parameters_in_total_layers
            + parameters_output # num_parameters_output_layernorm
        )    
    else:
        num_parameters_total = (
            num_parameters_input
            + num_parameters_in_total_layers
        )   
    
    return num_parameters_total
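
# A minimal sanity-check sketch (an assumed example, not wired into the app): with the UI's
# default LLaMA-7B-like config (L=32, H=4096, FFN=11008, S=2048, A=32, V=32000, no bias,
# no GQA, untied embeddings) and tp=pp=1, the count lands near the well-known ~6.7B parameters.
def _example_parameter_count():
    n = Compute_Parameters(
        seq_length=2048, vocab_size=32000, layer_num=32, hidden_size=4096, ffn_size=11008,
        is_group_query="False", group_query_num=32, is_bias="False",
        is_tie_word_embedding="False", act_func="LLaMA", head_num=32, tp=1, pp=1)
    return Get_BillionParameter(n)  # roughly 6.7 (billions of parameters)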

def Compute_Weight(numParametersTotal, precision, is_fp8, is_fp8_init):
    weight_memory = 0
    if precision == "FP32":
        weight_memory = 4 * numParametersTotal
    else:
        weight_memory = 2 * numParametersTotal

    if is_fp8 == "True" and is_fp8_init == "False":
        weight_memory += 2 * numParametersTotal
    
    return weight_memory

def Compute_Gradient(numParametersTotal, g_ty):
    if g_ty == "FP32":
        gradient_memory = 4 * numParametersTotal 
    elif g_ty =="BF16":
        gradient_memory = 2 * numParametersTotal
    
    return gradient_memory

def Compute_Optimizer_states(numParametersTotal, opt_func, o_ty, is_dist_opt, dp, cp):
    if o_ty == "FP32":
        optimizer_memory = 4 * 2 * numParametersTotal 
    elif o_ty =="BF16":
        optimizer_memory = 2 * 2 * numParametersTotal
    
    if is_dist_opt == "True":
        optimizer_memory = optimizer_memory / (dp * cp)

    # for SGD, we have no optimizer states
    if opt_func == "SGD":
        optimizer_memory = 0

    return optimizer_memory

def Compute_Master_weight(numParametersTotal, precision, is_dist_opt, dp, cp):
    if precision == "BF16":
        master_weight_memory = 4 * numParametersTotal
    else:
        master_weight_memory = 0
    if is_dist_opt == "True":
        master_weight_memory = master_weight_memory / (dp * cp)
    
    return master_weight_memory

def Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func,
        dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty):
    numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, tp, pp)

    weight_memory = Compute_Weight(numParametersTotal, precision, is_fp8, is_fp8_init)
    gradient_memory = Compute_Gradient(numParametersTotal, g_ty)
    optimizer_memory = Compute_Optimizer_states(numParametersTotal, opt_func, o_ty, is_dist_opt, dp, cp)
    master_weight_memory = Compute_Master_weight(numParametersTotal, precision, is_dist_opt, dp, cp)

    return numParametersTotal, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, \
            weight_memory + gradient_memory + optimizer_memory + master_weight_memory
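
# Note on the arithmetic above: with BF16 weights (2 B/param), FP32 gradients (4 B/param),
# FP32 Adam states (8 B/param) and an FP32 master copy (4 B/param), model states cost roughly
# 18 bytes per parameter; the optimizer states and master weights are further sharded by
# dp * cp when the distributed optimizer is enabled.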

# activation memory:
def compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp):
    # LN 2bsh
    activation_mem_attn_ln = seq_length * b * hidden_size * training_dtype
    if is_sp == "False":
        activation_mem_attn_ln *= tp
    # attention input X, qkv 2bsh/1bsh
    activation_mem_attn_qkv = seq_length * b * hidden_size * gemm_dtype
    if is_sp == "False":
        activation_mem_attn_qkv *= tp
    # attention q 2bsh
    activation_mem_attn_q = seq_length * b * hidden_size * training_dtype
    # attention k and v 4bsh
    activation_mem_attn_kv = seq_length * b * kv_hidden_size * training_dtype * 2
    # attention proj input 2bsh/1bsh
    activation_mem_attn_proj = seq_length * b * hidden_size * gemm_dtype
    # dropout bsh
    activation_mem_attn_dropout = seq_length * b * hidden_size
    if is_sp == "False":
        activation_mem_attn_dropout *= tp
    # bf16: 2+2+2+4+2+1=13bsh
    # fp8: 2+1+2+4+1+1=11bsh
    activation_memory_attn = (
        activation_mem_attn_ln
        + activation_mem_attn_qkv
        + activation_mem_attn_q 
        + activation_mem_attn_kv 
        + activation_mem_attn_proj 
        + activation_mem_attn_dropout
    )
    return activation_memory_attn

def compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp):
    # LN 2bsh
    activation_mem_mlp_ln = seq_length * b * hidden_size * training_dtype
    if is_sp == "False":
        activation_mem_mlp_ln *= tp
    # FC1 2bsh/1bsh
    activation_mem_mlp_fc1 = seq_length * b * hidden_size * gemm_dtype
    if is_sp == "False":
        activation_mem_mlp_fc1 *= tp
    # Act 8bsh
    if act_func == "LLaMA":
        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype * 2
    else:
        activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype
    # FC2 8bsh/4bsh
    activation_mem_mlp_fc2 = seq_length * b * ffn_size * gemm_dtype
    # dropout bsh
    activation_mem_mlp_dropout = seq_length * b * hidden_size
    if is_sp == "False":
        activation_mem_mlp_dropout *= tp
    # bf16: 2+2+8+8+1=21
    # fp8: 2+1+8+4+1=16
    activation_memory_mlp = (
        activation_mem_mlp_ln
        + activation_mem_mlp_fc1
        + activation_mem_mlp_act
        + activation_mem_mlp_fc2
        + activation_mem_mlp_dropout
    )
    return activation_memory_mlp

def compute_activation_memory_input(seq_length, b, hidden_size, pp):
    # embedding + Dropout
    return 8 * seq_length * b * pp + seq_length * b * hidden_size * pp

def compute_activation_memory_output(seq_length, b, hidden_size, vocab_size):
    # Inputs to output layer and CE loss(bf16, fp32 * 2).
    return 2 * seq_length * b * hidden_size + (2 + 4 + 4) * seq_length * b * vocab_size

def compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches):
    # Multiply by interleaved PP memory factor.
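    # e.g. pp=4, vp=2 gives a factor of 1 + (pp-1)/(pp*vp) = 1.375 (extra microbatch activations held by the interleaved schedule)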
    if vp > 0:
        interleaved_schedule_memory_penalty = 1 + (pp - 1) / (pp * vp)
        activation_memory *= interleaved_schedule_memory_penalty

    # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size,
    # so discount accordingly.
    if vp == 0 and pp > 1:
        if num_microbatches > 1:
            activation_memory *= min(1, num_microbatches / pp)

    return activation_memory 

def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp):
    # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
    # We are trying to compute the maximum activation footprint, so all calculations in this function
    # are for the first pipeline stage.

    # activation dataType for Training
    if precision == "FP32":
        training_dtype = 4
    else:
        training_dtype = 2

    # activation dataType for GEMM
    if precision == "FP32":
        gemm_dtype = 4
    elif is_fp8 == "False":
        gemm_dtype = 2
    else:
        gemm_dtype = 1

    # kv_hidden_size
    if is_group_query == "False":
        group_query_num = head_num
    kv_hidden_size = hidden_size / head_num * group_query_num

    activation_memory_attn = compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp)

    activation_memory_mlp = compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp)

    activation_memory = activation_memory_attn + activation_memory_mlp

    activation_memory *= layer_num

    # Now add activation memory required for input embeddings, last LayerNorm and output layer.
    # Input to embedding (pp_size microbatches in flight).
    activation_memory_input = compute_activation_memory_input(seq_length, b, hidden_size, pp)
    activation_memory += activation_memory_input

    # get num_microbatches
    num_microbatches = b_global / b / dp / cp
    activation_memory = compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches)

    if pp == 1:
        # Inputs to output layer and CE loss(fp32).
        activation_memory_output = compute_activation_memory_output(seq_length, b, hidden_size, vocab_size)
        activation_memory += activation_memory_output
    elif pp > 1:
        # Sendrecv memory
        activation_memory += seq_length * b * hidden_size * 2

    # Activation memory is partitioned by TP size due to tensor and sequence model parallelism.
    return activation_memory / tp / cp
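
# A minimal usage sketch (an assumed example, not wired into the app): activation memory of the
# first pipeline stage for the UI's default config (LLaMA-7B-like, micro batch 4, global batch 64,
# BF16 training with FP8 GEMMs, sequence parallelism on, dp=tp=pp=2, cp=1, vp=0).
def _example_activation_memory():
    act = compute_activation_memory(
        vocab_size=32000, seq_length=2048, layer_num=32, b=4, b_global=64, head_num=32,
        hidden_size=4096, ffn_size=11008, act_func="LLaMA", precision="BF16", is_fp8="True",
        is_sp="True", is_group_query="False", group_query_num=32, tp=2, pp=2, dp=2, cp=1, vp=0)
    return Get_GigaByte(act)  # first-stage activation footprint in GB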

# compute_btn.click.function
def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_length, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func,
        dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty, record_df, count):
    # data type trans
    if is_group_query == "True":
        group_query_num = int(group_query_num)
    else:
        group_query_num = head_num

    # check input
    [result, Error_message] = check_input(dp, tp, pp, cp, hidden_size, head_num, layer_num, seq_length, vp, b, b_global)
    if result == False:
        return Error_message, record_df, count

    # get model states
    numParameters, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, model_states_memory = Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, 
        ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty)

    # get activation memory 
    activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp)

    # get model parameters
    numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, 1, 1)
    # get gpu number
    gpu_num = dp * tp * pp * cp

    # get B/GB
    numParametersTotal = round(Get_BillionParameter(numParametersTotal), 3)
    numParameters = round(Get_BillionParameter(numParameters), 3)
    model_states_memory = round(Get_GigaByte(model_states_memory), 3)
    activation_memory = round(Get_GigaByte(activation_memory), 3)
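    # flat 5 GB allowance for memory not modeled above (presumably CUDA context, communication buffers, fragmentation)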
    other_memory = 5
    Total = round(model_states_memory + activation_memory + other_memory, 3)

    # record
    new_row = pd.DataFrame([[layer_num, hidden_size, ffn_size, seq_length, head_num, group_query_num, dp, tp, pp, cp, gpu_num, b, is_fp8, 
                            numParametersTotal, model_states_memory, activation_memory, Total]], 
                            columns=col)
    if count == 1:
        record_df = new_row
    else:    
        record_df = pd.concat([record_df, new_row], ignore_index=True)
    count = count + 1

    # return str(gpu_num), str(model_states) + " GB", str(activation) + " GB", str(total) + " GB", table_data
    return f"""
                GPU numbers = {str(gpu_num)}, \n
                Model parameters = {str(numParametersTotal)} B, \n
                Model parameters on each device = {str(numParameters)} B, \n
                Model_states = Weight + Gradient + Optimizer = {str(model_states_memory)} GB, \n
                Activation = {str(activation_memory)} GB, \n
                Other memory = 5 GB, \n
                Total memory consumption = {str(Total)} GB \n
           """, record_df, count

def generate_csv(record_df):
    # Save the DataFrame as a CSV file
    csv_filename = "data.csv"
    record_df.to_csv(csv_filename, index=False)
    
    # Return the path to the CSV file
    return csv_filename

# formula string
formula = r"""
        > **Note**🔑: In this formula, we assume LLM training with FP8.
        > 1. LLaMA-family model.
        > 2. Interleaved pipeline.
        > 3. bias = False.
        > 4. SP = True. 

        <div align="center">
        <img src=file/T1.jpg width=50%/>
        </div>

        $$
        {Total\ Model\ parameters} = 
        HV + (4H^2 + 3H \times FFN + 2H) \times L + H
        $$

        ***

        <div align="center">
        <img src=file/ms.png width=40%/>
        </div>
        
        $$
        {Model\ states} = 
        (6 + \frac{12}{dp \times cp}) \times
        (\frac{(\frac{4H^2 + 3H \times FFN}{tp} + 2H) \times L}{pp} + \frac{HV}{tp})
        $$

        $$
        {Activation} = 
        (1 + \frac{pp-1}{pp \times vp}) \times
        \frac{(8BS + BSH) \times pp + (15BSH + 5BS \times FFN) \times L}{tp \times cp}
        $$

        ***

        $$
        \\begin{gather}
        {GPU\ numbers} = tp \times pp \times dp \times cp\\\\
        {Total\ memory\ consumption} = {Model\ states} + Activation
        \\end{gather}
        $$
        """

def check_tp(tp, head_num):
    return head_num % tp == 0

def check_pp(pp, layer_num):
    return layer_num % pp == 0

def check_cp(cp, seq_length):
    return seq_length % cp == 0

def check_hidden(hidden_size, head_num):
    return hidden_size % head_num == 0

def check_b_global(b_global, b, dp, cp):
    return b_global % (b * dp * cp) == 0

def check_num_microbatch(layer_num, vp, pp, num_microbatches):
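    # With interleaving (vp > 0), layer_num must split evenly into pp * vp model chunks;
    # without interleaving and pp > 1, the microbatch count should be a multiple of pp.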
    if vp > 0:
        if layer_num % (pp * vp) == 0:
            return True
        else:
            return False

    if vp == 0 and pp > 1:
        if num_microbatches > 1:
            if num_microbatches % pp == 0:
                return True
            else:
                return False
    return True
    

def check_input(dp, tp, pp, cp, hidden_size, head_num, layer_num, seq_length, vp, b, b_global):
    result = True
    Error_message = ""
    if not check_tp(tp, head_num):
        result = False
        Error_message += "Error message: Please reset Tensor parallelism or head_num so that head_num % tp == 0. \n"
    if not check_pp(pp, layer_num):
        result = False
        Error_message += "Error message: Please reset Pipeline parallelism or layer_num so that layer_num % pp == 0. \n"
    if not check_cp(cp, seq_length):
        result = False
        Error_message += "Error message: Please reset Context parallelism or seq_length so that seq_length % cp == 0. \n"
    if not check_hidden(hidden_size, head_num):
        result = False
        Error_message += "Error message: Please reset hidden_size or head_num so that hidden_size % head_num == 0. \n"
    if not check_b_global(b_global, b, dp, cp):
        result = False
        Error_message += "Error message: Please reset b_global or batch_size so that b_global % (batch_size * dp * cp) == 0. \n"
    if not check_num_microbatch(layer_num, vp, pp, b_global / b / dp / cp):
        result = False
        Error_message += "Error message: Please reset b_global, batch_size, layer_num or Virtual Pipeline Size so that layer_num % (pp * vp) == 0 and num_microbatches % pp == 0. \n"
    
    return result, Error_message

with gr.Blocks() as demo:
    with gr.Row():
        # Text
        gr.Markdown(
            """
            <div style="text-align: center;">
                <h1>GPU memory calculator 🌀</h1>
                <p style="font-size:16px;">Here's a GPU memory calculator, it helps you to compute memory comsumption in LLM training. </p>
                <p style="font-size:16px;">Note: Flash-attention is enabled by default. </p>
            </div>
            """
        )

    with gr.Row():
        with gr.Column(): 
            # Input 1.[Model Parameters]
            gr.Markdown(
                """
                <h2>Model Parameters:</h2>
                """
            )
            with gr.Accordion("Model Parameters"):
                # with gr.Row():
                act_func = gr.Radio(["LLaMA", "GPT"], value="LLaMA", label="Model type", info="e.g. LLaMA: SwiGLU, RoPE, RMSNorm") #, info="Activation function in MLP, whether to use GLU (Gated Linear Unit). [e.g. \"True\" for LLaMA, \"False\" for GPT.]")
                with gr.Row():
                    vocab_size = gr.Number(label="Vocab size (V)", value=32000)
                    layer_num = gr.Number(label="Layer number (L)", value=32)
                with gr.Row():
                    hidden_size = gr.Number(label="Hidden size (H)", value=4096)
                    ffn_size = gr.Number(label="FFN Hidden size (FFN)", value=11008)
                with gr.Row():
                    sequence_len = gr.Number(label="Sequence length (S)", value=2048)
                    head_num = gr.Number(label="Number of Attention Heads (A)", value=32)
                with gr.Row():
                    is_group_query = gr.Radio(["True", "False"], value="False", label="Use Group Query Attention")
                    group_query_num = gr.Textbox(label="Number of Query Groups (G)", max_lines=1, value=None, interactive=False)
                with gr.Row():
                    is_bias = gr.Radio(["True", "False"], value="False", label="Use Bias")
                    is_tie_word_embedding = gr.Radio(["True", "False"], value="False", label="Tie word embeddings")
                # change editable function
                def toggle_textbox_editable(radio_value):
                    # Decide whether the textbox is editable based on the radio value
                    if radio_value == "True":
                        return gr.update(interactive=True, value="96")
                    else:
                        return gr.update(interactive=False, value="")
                # Connect the radio component's change event to the toggle function
                is_group_query.change(toggle_textbox_editable, inputs=is_group_query, outputs=group_query_num)

        with gr.Column():    
            # Input 2.[Parallelism]
            gr.Markdown(
                """
                <h2>Parallelism config:</h2>
                """
            )
            with gr.Accordion("Parallelism config"):
                # with gr.Row():
                dp = gr.Number(label="Data parallelism (dp)", value=2)
                tp = gr.Number(label="Tensor parallelism (tp)", value=2)
                pp = gr.Number(label="Pipeline parallelism (pp)", value=2)
                cp = gr.Number(label="Context parallelism (cp)", value=1)
                # with gr.Row():
                is_sp = gr.Radio(["True", "False"], value="True", label="Sequence parallelism")
                vp = gr.Number(label="Virtual Pipeline Size (vp)")
                is_dist_opt = gr.Radio(["True", "False"], value="True", label="Use Distributed Optimizer(Zero1)")

        with gr.Column():
            # Input 3.[Training Settings]
            gr.Markdown(
                """
                <h2>Training Config:</h2>
                """
            )
            with gr.Accordion("Training Config"):
                # with gr.Row():
                b = gr.Number(label="Micro Batch size (B)", value=4)
                b_global = gr.Number(label="Global Batch size", value=64)
                precision = gr.Dropdown(["FP32", "BF16"], value="BF16", label="Training precision")
                with gr.Row():
                    is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training")
                    is_fp8_init = gr.Radio(["True", "False"], value="True", label="FP8 Initialization (will reduce memory)")
                g_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Gradients Dtype")
                with gr.Row():
                    opt_func = gr.Radio(["Adam", "SGD"], value="Adam", label="Optimizer function")
                    o_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Optimizer State Dtype")

    compute_btn = gr.Button("Compute")
    with gr.Tab("Output"):
        with gr.Column():
            # gr.Markdown(
            #     """
            #     <h1>Output Data:</h1>
            #     """
            # )
            output_text = gr.Textbox(
                label="Compute result", 
                interactive=False, 
            )
            
    with gr.Tab("Formula"):

        gr.Markdown(
            formula
            , latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }]
        )

    # gr.Markdown(abbr)

    record_df = gr.Dataframe(
        label="Record Table",
        headers=col,
        interactive=False
    )
    download_btn = gr.Button("Download")
    count = gr.Number(label="Row count", value=1, visible=False)
    compute_btn.click(
        fn=Compute_ALL_Model_memory, 
        inputs=[vocab_size, layer_num, hidden_size, ffn_size, sequence_len, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func,
                dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty, record_df, count],
        outputs=[output_text, record_df, count]
    )

    output_file = gr.File(label="When you click the Download button, the exported CSV file will appear here.")
    # download func
    download_btn.click(
        fn=generate_csv,
        inputs=record_df,
        outputs=output_file
    )


if __name__ == "__main__":
    demo.launch(share=False, allowed_paths=["/"])