File size: 2,787 Bytes
addbb37
 
 
f8eee5a
 
addbb37
 
 
 
f8eee5a
addbb37
 
 
 
f8eee5a
db805e9
 
f8eee5a
 
 
 
 
 
addbb37
 
f8eee5a
 
 
 
 
 
 
addbb37
 
 
 
 
f8eee5a
 
 
 
 
 
addbb37
 
f8eee5a
 
 
 
 
addbb37
f8eee5a
 
 
 
addbb37
 
 
 
f8eee5a
 
addbb37
 
 
 
 
 
 
 
 
 
f8eee5a
addbb37
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import matplotlib.pyplot as plt

import matplotlib.pyplot as plt


def _memory_breakdown_gb(num_param, batch_size, precision, seq_len):
    """Return the estimated memory components, in GB, for one training step.

    Args:
        num_param: Number of model parameters, in billions.
        batch_size: Number of sequences per batch.
        precision: Element dtype name: "float32", "float16" or "bfloat16".
        seq_len: Sequence length.

    Returns:
        Tuple ``(params, optimizer_states, activations, gradients,
        optimizer_intermediates)``, each in GB (1 GB = 1000**3 bytes).

    Raises:
        ValueError: If ``precision`` is unsupported or ``num_param`` is
            negative.
    """
    bytes_per_element = {"float32": 4, "float16": 2, "bfloat16": 2}
    try:
        nbytes = bytes_per_element[precision]
    except KeyError:
        raise ValueError(
            f"Unsupported precision {precision!r}; "
            f"expected one of {sorted(bytes_per_element)}"
        ) from None

    n = float(num_param) * 1e9  # input is given in billions
    if n < 0:
        raise ValueError(f"Number of parameters must be >= 0, got {num_param!r}")

    gb = 1000**3
    # Model Parameters: N×precision
    params = n * nbytes / gb
    # Optimizer States: 2×N×precision
    optimizer_states = 2 * n * nbytes / gb
    # Activations: B×Sequence Length×K×precision, where K is modeled as a
    # linear function of N (constants carried over unchanged; their
    # derivation is not shown here).
    k = 4.6894e-4 * n + 1.8494e6
    activations = batch_size * seq_len * k * nbytes / gb
    # Gradients: N×precision
    gradients = n * nbytes / gb
    # Optimizer intermediates: N×precision
    optimizer_intermediates = n * nbytes / gb
    return params, optimizer_states, activations, gradients, optimizer_intermediates


def plot_forecast(num_param, batch_size, precision, seq_len):
    """Draw a stacked bar chart of estimated training memory.

    Parameters and optimizer states form a shared base. On top of it, two
    half-width columns are drawn side by side: activations on one, gradients
    plus optimizer intermediates on the other. The total counts only the
    taller of the two columns.

    Args:
        num_param: Number of model parameters, in billions.
        batch_size: Number of sequences per batch.
        precision: Element dtype name: "float32", "float16" or "bfloat16".
        seq_len: Sequence length.

    Returns:
        A ``matplotlib.figure.Figure`` with the chart.

    Raises:
        ValueError: If ``precision`` is unsupported or ``num_param`` is
            negative.
    """
    # Build the Figure directly rather than through pyplot. Pyplot keeps
    # every figure it creates alive until closed, so a long-running server
    # would otherwise accumulate one figure per request.
    from matplotlib.figure import Figure

    y1, y2, y3, y4, y5 = _memory_breakdown_gb(num_param, batch_size, precision, seq_len)

    # Only the larger of activations vs. (gradients + intermediates) is
    # counted, presumably because they are assumed not to peak together.
    total_memory = y1 + y2 + max(y3, y4 + y5)

    fig = Figure(figsize=(4, 4))
    ax = fig.add_subplot(111)

    # Shared base spans the full width; the two upper columns are half-width
    # and offset left/right of center.
    bar_width = 0.5
    left, right, half = -bar_width / 4, bar_width / 4, bar_width / 2
    base = y1 + y2
    ax.bar(0, y1, width=bar_width, color="r")
    ax.bar(0, y2, bottom=y1, width=bar_width, color="b")
    ax.bar(left, y3, bottom=base, width=half, color="g")
    ax.bar(right, y4, bottom=base, width=half, color="y")
    ax.bar(right, y5, bottom=base + y4, width=half, color="c")

    # Label each segment at its vertical midpoint.
    label_style = dict(ha="center", va="center", color="white", fontweight="bold")
    ax.text(0, y1 / 2, f"Model Parameters ({y1:.1f} GB)", **label_style)
    ax.text(0, y1 + y2 / 2, f"Optimizer States ({y2:.1f} GB)", **label_style)
    ax.text(left, base + y3 / 2, f"Activations\n({y3:.1f} GB)", **label_style)
    ax.text(right, base + y4 / 2, f"Gradients\n({y4:.1f} GB)", **label_style)
    ax.text(right, base + y4 + y5 / 2, f"Optimizer\nintermediates\n({y5:.1f} GB)", **label_style)

    ax.set_title(f"Total Memory: {total_memory:.1f} GB", fontweight="bold")
    ax.xaxis.set_visible(False)  # bar x positions carry no meaning
    ax.set_ylabel("Memory (GB)")

    fig.tight_layout()
    return fig


# Gradio UI: four inputs feed plot_forecast, and the returned figure is
# shown in a Plot output rendered as PNG.
demo = gr.Interface(
    fn=plot_forecast,
    inputs=[
        gr.Number(7, label="Number of parameters (B)"),
        gr.Radio([1, 2, 4, 8, 16, 32, 64, 128], value=8, label="Batch size"),
        gr.Radio(["float32", "float16", "bfloat16"], value="float32", label="Precision"),
        gr.Slider(1, 1000, label="Sequence Length", step=1, value=128),
    ],
    # The chart is a memory breakdown, not a forecast.
    outputs=gr.Plot(label="Memory usage", format="png"),
)

if __name__ == "__main__":
    demo.launch()