# Page header captured together with the file (not part of the program):
# sayakpaul (HF Staff) - "Update app.py" - commit b2ee710 (verified) - 5 kB
# Thanks ChatGPT for pairing.
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
# The benchmark table is read a single time, when the module is imported.
CSV_PATH = "collated_results.csv"  # Place your CSV file here
df = pd.read_csv(CSV_PATH)
df = df.reset_index(drop=True)

# Options offered by the UI controls.
model_choices = sorted(df["model_cls"].dropna().unique().tolist())
metric_choices = [
    "num_params_B",
    "flops_G",
    "time_plain_s",
    "mem_plain_GB",
    "time_compile_s",
    "mem_compile_GB",
]
group_choices = ["scenario"]
def filter_float(value):
    """Convert a string cell to a float using its first whitespace-separated token.

    A string such as ``"1.5 GB"`` becomes ``1.5``. Anything that is not a
    string is handed back untouched.
    """
    if not isinstance(value, str):
        return value
    leading_token = value.split()[0]
    return float(leading_token)
# Analysis function using global df
def _hide_top_right_spines(ax):
    """Hide the top and right frame lines of *ax*."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)


def _no_data_panel(ax):
    """Turn *ax* into a blank panel that only shows the 'no data' message."""
    ax.text(0.5, 0.5, 'No data for selected model', ha='center', va='center', fontsize=14)
    ax.set_axis_off()


def _detach_from_pyplot(fig):
    """Remove *fig* from pyplot's registry and return it.

    pyplot keeps every figure it creates alive until it is closed, so a
    long-running app that builds one figure per click would otherwise
    accumulate them. The returned Figure object itself stays usable.
    """
    plt.close(fig)
    return fig


def _plot_metric(model_rows, metric, selected_model):
    """Horizontal bar chart of *metric* per scenario for one model."""
    plot_df = model_rows.dropna(subset=[metric])
    fig, ax = plt.subplots(figsize=(10, 6))
    if plot_df.empty:
        _no_data_panel(ax)
        return _detach_from_pyplot(fig)

    # One bar per row; cells may be strings, so normalise them to floats.
    scenarios = plot_df['scenario']
    values = plot_df[metric].map(filter_float)
    bars = ax.barh(scenarios, values)

    fig.set_tight_layout(True)
    ax.set_xlabel(metric, fontsize=14)
    ax.set_ylabel('Scenario', fontsize=14)
    ax.set_title(f"{metric} per Scenario for {selected_model}", fontsize=16)
    ax.tick_params(axis='both', labelsize=12)
    ax.grid(axis='x', linestyle='--', alpha=0.5)
    _hide_top_right_spines(ax)

    # Value label just past the end of each bar (1% of the bar length).
    for bar in bars:
        w = bar.get_width()
        ax.text(w + w*0.01, bar.get_y() + bar.get_height()/2,
                f"{w:.3f}", va='center', fontsize=12)
    return _detach_from_pyplot(fig)


def _plot_times(model_rows, selected_model):
    """Grouped bar chart comparing plain and compile time per scenario."""
    filt = model_rows.dropna(subset=['time_plain_s', 'time_compile_s'])
    fig, ax = plt.subplots(figsize=(10, 6))
    if filt.empty:
        _no_data_panel(ax)
        return _detach_from_pyplot(fig)

    scenarios = filt['scenario']
    plain_times = filt['time_plain_s'].map(filter_float)
    # Named compile_times so the builtin compile() is not shadowed.
    compile_times = filt['time_compile_s'].map(filter_float)

    # Two bars per scenario, centred on the integer tick position.
    x = range(len(scenarios))
    width = 0.35
    bars_plain = ax.bar([i - width/2 for i in x], plain_times, width=width, label='Plain')
    bars_compile = ax.bar([i + width/2 for i in x], compile_times, width=width, label='Compile')

    ax.set_xticks(x)
    ax.set_xticklabels(scenarios, rotation=45, ha='right')
    ax.set_xlabel('Scenario', fontsize=14)
    ax.set_ylabel('Time (s)', fontsize=14)
    ax.set_title(f"Plain vs Compile Time for {selected_model}", fontsize=16)
    ax.tick_params(axis='both', labelsize=12)
    ax.legend(frameon=False)
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    _hide_top_right_spines(ax)

    # Value label just above the top of each bar (1% of the bar height).
    for bar in [*bars_plain, *bars_compile]:
        h = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2, h + h*0.01,
                f"{h:.3f}", ha='center', va='bottom', fontsize=12)
    return _detach_from_pyplot(fig)


def analyze(analysis_type, n_rows, metric, selected_model):
    """Run the selected analysis on the rows of the global ``df`` for one model.

    Args:
        analysis_type: "Preview Data", "Plot Metric" or
            "Plot Times per Scenario"; any other value yields no output.
        n_rows: Maximum number of rows returned by "Preview Data".
        metric: Column plotted by "Plot Metric".
        selected_model: Value of ``model_cls`` used to filter ``df``.

    Returns:
        A ``(table, figure)`` pair. "Preview Data" fills only the table, the
        two plot modes fill only the figure, and an unknown analysis type
        returns ``(None, None)``.
    """
    model_rows = df[df['model_cls'] == selected_model]

    if analysis_type == "Preview Data":
        preview_cols = [c for c in df.columns if c != "model_cls"]
        # head() needs an integer count; cast in case the UI delivers a float
        # (NOTE(review): depends on the Gradio version - the cast is harmless).
        return model_rows[preview_cols].head(int(n_rows)), None

    if analysis_type == "Plot Metric":
        return None, _plot_metric(model_rows, metric, selected_model)

    if analysis_type == "Plot Times per Scenario":
        return None, _plot_times(model_rows, selected_model)

    return None, None
# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# [Diffusers Benchmarks](https://huggingface.co/datasets/diffusers/benchmarks) Data Analyzer")

    # Input controls, in the order they are laid out on the page.
    model_dropdown = gr.Dropdown(
        label="Select model_cls",
        choices=model_choices,
        value=model_choices[0],
    )
    analysis_type = gr.Radio(
        label="Analysis Type",
        choices=["Preview Data", "Plot Metric", "Plot Times per Scenario"],
        value="Preview Data",
    )
    n_rows = gr.Slider(
        minimum=5,
        maximum=len(df),
        value=10,
        step=5,
        label="Number of rows to preview",
    )
    metric = gr.Dropdown(
        label="Metric to plot",
        choices=metric_choices,
        value="time_plain_s",
    )
    analyze_button = gr.Button("Analyze")

    # Output areas: analyze() fills either the table or the plot.
    tbl_output = gr.Dataframe(headers=None, label="Table Output")
    plot_output = gr.Plot(label="Plot Output")

    # The input order must match analyze()'s parameter order.
    analyze_button.click(
        fn=analyze,
        inputs=[analysis_type, n_rows, metric, model_dropdown],
        outputs=[tbl_output, plot_output],
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()