"""LLM training calculator.

A small Gradio app that, for a given model size, dataset size and GPU cluster,
reports training compute/cost/time and where the configuration sits on the
Chinchilla training-compute vs. model-size trade-off curve ("Harm's law").
"""
import math

import gradio as gr
import matplotlib

# Select a non-interactive backend before pyplot is imported (server use).
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from matplotlib.ticker import MultipleLocator

HARM_INTRO = """
The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).
"""

### GPU specs (peak FLOP/s per GPU):
A100_flops = 312e12
H100_flops = 990e12

# Dropdown value -> peak FLOP/s. Unknown values fall back to A100, matching
# the original if/else which treated anything other than "H100" as an A100.
GPU_FLOPS = {"A100": A100_flops, "H100": H100_flops}

### CHINCHILLA PARAMS for the loss model L(N, D) = E + A / N**alpha + B / D**beta:
E = 1.62  # irreducible loss; not needed by the trade-off computation below
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283
Bn = 10**9  # one billion, converts the UI's "B parameters/tokens" to raw counts
# Constant of the closed-form compute-optimal allocation used by n_opt/d_opt.
G = ((alpha * A) / (beta * B)) ** (1 / (alpha + beta))


### FUNCTIONS
def to_flops(N, D):
    """Return training compute in FLOPs for N parameters and D tokens (C = 6*N*D)."""
    return 6 * N * D


def n_opt(C):
    """Return the compute-optimal number of parameters for C training FLOPs."""
    return G * ((C / 6) ** (beta / (alpha + beta)))


def d_opt(C):
    """Return the compute-optimal number of training tokens for C training FLOPs."""
    return (1 / G) * ((C / 6) ** (alpha / (alpha + beta)))


def compute_kd(kn):
    """Return the data multiplier kd that keeps the loss at the compute-optimal level.

    Solves L(kn * N_opt, kd * D_opt) == L(N_opt, D_opt) for kd under the loss
    model above. The ratio D_opt**beta / N_opt**alpha equals G**(-alpha-beta)
    for every compute budget, which is why the result depends only on kn.

    Args:
        kn: model size as a fraction of the Chinchilla-optimal model size.

    Returns:
        kd as a float, or ``math.inf`` when the model is so small that no
        finite amount of data reaches the compute-optimal loss (the base of
        the power would be <= 0 and Python would return a complex number).
    """
    if kn <= 0:
        # Limit as kn -> 0+: the required data multiplier diverges.
        return math.inf
    frac = (A / B) * (G ** (-alpha - beta))
    base = 1 - (kn ** -alpha - 1) * frac
    if base <= 0:
        return math.inf
    return base ** (-1 / beta)


def compute_overhead(kn, kd):
    """Return the fractional extra training compute (kn*kd - 1) vs. Chinchilla-optimal.

    Training compute scales with N*D, so shrinking the model by kn and growing
    the data by kd multiplies compute by kn*kd.
    """
    return kn * kd - 1


### PRECOMPUTE CURVE (overhead in % over the plotted model-size fractions):
kn_min = 0.18
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
overheads = [compute_overhead(k, compute_kd(k)) * 100 for k in kns]


def plot_curve(kn, kd):
    """Plot the overhead curve and mark the user's configuration and Chinchilla optimum.

    Uses a standalone ``Figure`` rather than pyplot, so per-request figures are
    not registered in pyplot's global figure list (which would leak them) and
    concurrent requests do not share pyplot's "current figure" state.
    """
    fig = Figure(dpi=200, figsize=(5, 3))
    ax = fig.subplots()
    ax.plot(kns, overheads, color="black", zorder=1)
    overhead_pct = compute_overhead(kn, kd) * 100
    if math.isfinite(overhead_pct):
        # An infinite overhead has no point to draw; skip the marker.
        ax.scatter([kn], [overhead_pct], s=100, marker="o", c="red",
                   label="You are here!", zorder=2)
    ax.scatter([1.0], [0.0], marker="o", s=100, c="blue",
               label="Chinchilla optimal", zorder=2)
    ax.set_xlabel("Fraction of Chinchilla optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    ax.legend(loc="best")
    ax.grid(True, which="both")
    ax.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    fig.tight_layout()
    return fig


def _validate_inputs(N, D, gpu_util, n_gpus, gpu_price):
    """Raise gr.Error (shown in the UI) for missing inputs or ones that would divide by zero."""
    positive = {
        "Model size": N,
        "Dataset size": D,
        "GPU utilization": gpu_util,
        "Number of GPUs": n_gpus,
    }
    for label, value in positive.items():
        if value is None or value <= 0:
            raise gr.Error(f"{label} must be a positive number.")
    if gpu_price is None or gpu_price < 0:
        raise gr.Error("GPU price must be a non-negative number.")


def compute(N, D, gpu_type, gpu_util, n_gpus, gpu_price):
    """Compute the training summary and the Harm's-law plot for one configuration.

    Args:
        N: model size in billions of parameters.
        D: dataset size in billions of tokens.
        gpu_type: "A100" or "H100" (anything else is treated as A100).
        gpu_util: achieved GPU utilization in percent (e.g. 50 for 50%).
        n_gpus: number of GPUs in the cluster.
        gpu_price: price in $ per GPU-hour.

    Returns:
        (markdown_text, matplotlib_figure) for the Markdown and Plot outputs.

    Raises:
        gr.Error: if an input is missing or out of range.
    """
    _validate_inputs(N, D, gpu_util, n_gpus, gpu_price)

    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)
    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    overhead = compute_overhead(kn, kd)
    fig = plot_curve(kn, kd)

    # The UI collects utilization as a percentage; convert it to a fraction.
    gpu_flops = GPU_FLOPS.get(gpu_type, A100_flops) * (gpu_util / 100)
    gpu_hours = C / (gpu_flops * 3600)
    # Wall-clock time assumes the work is spread evenly over all GPUs.
    train_days = gpu_hours / (24 * n_gpus)

    if math.isfinite(overhead):
        overhead_cell = f"{100 * overhead:.2f}%"
    else:
        overhead_cell = "∞ (model too small to reach the optimal loss)"

    text = "\n".join([
        "## Training summary",
        "",
        "|Training compute| Training cost | Training time | Total GPU hours |",
        "|:----|:-------|:-------|:-------|",
        f"|{C:.2E} FLOPs | ${(gpu_hours * gpu_price) / 1e6:.2f}M "
        f"| {train_days:.2f} days | {gpu_hours / 1_000_000:.2f}M |",
        "",
        "## Chinchilla and Training/Inference Trade-off",
        "",
        "Optimal model/dataset size for training compute and how it translates "
        "to training overhead and inference savings according to Harm's law",
        "",
        "|Chinchilla optimal model | Chinchilla optimal dataset | Training overhead | Inference savings|",
        "|:----|:-------|:----|:-------|",
        f"| {N_opt / Bn:.2f}B parameters | {D_opt / Bn:.2f}B tokens "
        f"| {overhead_cell}| {100 - kn * 100:.2f}% |",
        "",
    ])
    return text, fig


with gr.Blocks() as demo:
    gr.Markdown("# LLM training calculator")
    gr.Markdown("## Training configuration")
    with gr.Row():
        N = gr.Number(value=7, label="Model size (in B parameters):")
        D = gr.Number(value=2000, label="Dataset size (in B tokens):")

    gr.Markdown("## Cluster configuration")
    with gr.Row():
        n_gpus = gr.Number(value=1000, label="Number of GPUs")
        gpu_type = gr.Dropdown(choices=["A100", "H100"], value="H100", label="GPU type")
        gpu_util = gr.Number(value=50, label="% GPU utilization")
        gpu_price = gr.Number(value=3.00, label="$/GPU/Hour")

    button = gr.Button("Compute!")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Harm's law")
            # Filled by demo.load below; no placeholder figure needed.
            plot = gr.Plot()
            gr.Markdown(HARM_INTRO)
        with gr.Column():
            md = gr.Markdown("")

    inputs = [N, D, gpu_type, gpu_util, n_gpus, gpu_price]
    button.click(fn=compute, inputs=inputs, outputs=[md, plot])
    # Populate the outputs with the default configuration on page load.
    demo.load(fn=compute, inputs=inputs, outputs=[md, plot])


if __name__ == "__main__":
    demo.launch()