sayakpaul HF Staff commited on
Commit
b087582
·
verified ·
1 Parent(s): 0216c2f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +127 -0
  2. collated_results.csv +14 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+
5
+ # Load CSV once at startup
6
+ CSV_PATH = "collated_results.csv" # Place your CSV file here
7
+
8
+ df = pd.read_csv(CSV_PATH).reset_index(drop=True)
9
+
10
+ # Prepare dropdown choices\ nmodel_choices = sorted(df['model_cls'].dropna().unique().tolist())
11
+ metric_choices = ["num_params_B", "flops_G", "time_plain_s", "mem_plain_GB", "time_compile_s", "mem_compile_GB"]
12
+ group_choices = ["scenario"]
13
+
14
+ # Analysis function using global df
15
+ def analyze(analysis_type, n_rows, metric, selected_model):
16
+ columns = df.columns
17
+ preview_cols = [c for c in columns if c != "model_cls"]
18
+
19
+ if analysis_type == "Preview Data":
20
+ return df[df['model_cls'] == selected_model][preview_cols].head(n_rows), None
21
+
22
+ # ——— Updated Plot Metric ———
23
+ if analysis_type == "Plot Metric":
24
+ plot_df = df[df['model_cls'] == selected_model].dropna(subset=[metric])
25
+ # empty-data guard
26
+ fig, ax = plt.subplots(figsize=(10, 6))
27
+ if plot_df.empty:
28
+ ax.text(0.5, 0.5, 'No data for selected model', ha='center', va='center', fontsize=14)
29
+ # prettify
30
+ for spine in ['top','right']:
31
+ ax.spines[spine].set_visible(False)
32
+ ax.set_axis_off()
33
+ return None, fig
34
+
35
+ # prepare bars
36
+ scenarios = plot_df['scenario']
37
+ values = plot_df[metric]
38
+ bars = ax.barh(scenarios, values)
39
+
40
+ # prettify
41
+ fig.set_tight_layout(True)
42
+ ax.set_xlabel(metric, fontsize=14)
43
+ ax.set_ylabel('Scenario', fontsize=14)
44
+ ax.set_title(f"{metric} per Scenario for {selected_model}", fontsize=16)
45
+ ax.tick_params(axis='both', labelsize=12)
46
+ ax.grid(axis='x', linestyle='--', alpha=0.5)
47
+ for spine in ['top','right']:
48
+ ax.spines[spine].set_visible(False)
49
+
50
+ # data labels
51
+ for bar in bars:
52
+ w = bar.get_width()
53
+ ax.text(w + w*0.01, bar.get_y() + bar.get_height()/2,
54
+ f"{w:.3f}", va='center', fontsize=12)
55
+
56
+ return None, fig
57
+
58
+ # ——— Plot Times per Scenario unchanged (already prettified) ———
59
+ if analysis_type == "Plot Times per Scenario":
60
+ filt = df[df['model_cls'] == selected_model]
61
+ filt = filt.dropna(subset=['time_plain_s', 'time_compile_s'])
62
+ fig, ax = plt.subplots(figsize=(10, 6))
63
+ if filt.empty:
64
+ ax.text(0.5, 0.5, 'No data for selected model', ha='center', va='center', fontsize=14)
65
+ for spine in ['top','right']:
66
+ ax.spines[spine].set_visible(False)
67
+ ax.set_axis_off()
68
+ return None, fig
69
+
70
+ scenarios = filt['scenario']
71
+ plain = filt['time_plain_s']
72
+ compile = filt['time_compile_s']
73
+ x = range(len(scenarios))
74
+ width = 0.35
75
+
76
+ bars_plain = ax.bar([i - width/2 for i in x], plain, width=width, label='Plain')
77
+ bars_compile = ax.bar([i + width/2 for i in x], compile, width=width, label='Compile')
78
+
79
+ ax.set_xticks(x)
80
+ ax.set_xticklabels(scenarios, rotation=45, ha='right')
81
+ ax.set_xlabel('Scenario', fontsize=14)
82
+ ax.set_ylabel('Time (s)', fontsize=14)
83
+ ax.set_title(f"Plain vs Compile Time for {selected_model}", fontsize=16)
84
+ ax.tick_params(axis='both', labelsize=12)
85
+ ax.legend(frameon=False)
86
+ ax.grid(axis='y', linestyle='--', alpha=0.5)
87
+ for spine in ['top','right']:
88
+ ax.spines[spine].set_visible(False)
89
+
90
+ # data labels
91
+ for bar in bars_plain + bars_compile:
92
+ h = bar.get_height()
93
+ ax.text(bar.get_x() + bar.get_width()/2, h + h*0.01,
94
+ f"{h:.3f}", ha='center', va='bottom', fontsize=12)
95
+
96
+ return None, fig
97
+
98
+ return None, None
99
+
100
+ # Build Gradio interface
101
+ demo = gr.Blocks()
102
+ with demo:
103
+ gr.Markdown("# CSV Data Analyzer")
104
+
105
+ model_dropdown = gr.Dropdown(label="Select model_cls", choices=model_choices, value=model_choices[0])
106
+
107
+ analysis_type = gr.Radio(
108
+ choices=["Preview Data", "Plot Metric", "Plot Times per Scenario"],
109
+ label="Analysis Type",
110
+ value="Preview Data"
111
+ )
112
+
113
+ n_rows = gr.Slider(5, len(df), step=5, label="Number of rows to preview", value=10)
114
+ metric = gr.Dropdown(choices=metric_choices, label="Metric to plot", value="time_plain_s")
115
+
116
+ analyze_button = gr.Button("Analyze")
117
+ tbl_output = gr.Dataframe(headers=None, label="Table Output")
118
+ plot_output = gr.Plot(label="Plot Output")
119
+
120
+ analyze_button.click(
121
+ fn=analyze,
122
+ inputs=[analysis_type, n_rows, metric, model_dropdown],
123
+ outputs=[tbl_output, plot_output]
124
+ )
125
+
126
+ if __name__ == "__main__":
127
+ demo.launch()
collated_results.csv ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scenario,model_cls,num_params_B,flops_G,time_plain_s,mem_plain_GB,time_compile_s,mem_compile_GB,fullgraph,mode,github_sha
2
+ black-forest-labs/FLUX.1-dev-bf16,FluxTransformer2DModel,11.9,59529.52,0.534,22.61,0.383,22.7,True,default,807f5113860869a15c6639377745dabce5d0d1ba
3
+ black-forest-labs/FLUX.1-dev-bnb-nf4,FluxTransformer2DModel,5.95,,0.563,6.7,,,,,807f5113860869a15c6639377745dabce5d0d1ba
4
+ black-forest-labs/FLUX.1-dev-layerwise-upcasting,FluxTransformer2DModel,11.9,59529.52,0.599,22.18,,,,,807f5113860869a15c6639377745dabce5d0d1ba
5
+ black-forest-labs/FLUX.1-dev-group-offload-leaf,FluxTransformer2DModel,11.9,59529.52,1.897,0.53,,,,,807f5113860869a15c6639377745dabce5d0d1ba
6
+ Lightricks/LTX-Video-0.9.7-dev-bf16,LTXVideoTransformer3DModel,13.04,167583.45,1.598,25.21,1.079,25.31,True,default,807f5113860869a15c6639377745dabce5d0d1ba
7
+ Lightricks/LTX-Video-0.9.7-dev-layerwise-upcasting,LTXVideoTransformer3DModel,13.04,167583.45,1.656,24.38,,,,,807f5113860869a15c6639377745dabce5d0d1ba
8
+ Lightricks/LTX-Video-0.9.7-dev-group-offload-leaf,LTXVideoTransformer3DModel,13.04,167583.45,2.762,1.04,,,,,807f5113860869a15c6639377745dabce5d0d1ba
9
+ stabilityai/stable-diffusion-xl-base-1.0-bf16,UNet2DConditionModel,2.57,5979.1,0.076,5.05,0.054,5.24,True,default,807f5113860869a15c6639377745dabce5d0d1ba
10
+ stabilityai/stable-diffusion-xl-base-1.0-layerwise-upcasting,UNet2DConditionModel,2.57,5979.1,0.154,4.89,,,,,807f5113860869a15c6639377745dabce5d0d1ba
11
+ stabilityai/stable-diffusion-xl-base-1.0-group-offload-leaf,UNet2DConditionModel,2.57,5979.1,0.524,0.2,,,,,807f5113860869a15c6639377745dabce5d0d1ba
12
+ Wan-AI/Wan2.1-T2V-14B-Diffusers-bf16,WanTransformer3DModel,14.29,785611.67,10.922,31.17,8.456,31.77,True,default,807f5113860869a15c6639377745dabce5d0d1ba
13
+ Wan-AI/Wan2.1-T2V-14B-Diffusers-layerwise-upcasting,WanTransformer3DModel,14.29,785611.67,10.766,26.78,,,,,807f5113860869a15c6639377745dabce5d0d1ba
14
+ Wan-AI/Wan2.1-T2V-14B-Diffusers-group-offload-leaf,WanTransformer3DModel,14.29,785611.67,11.262,4.48,,,,,807f5113860869a15c6639377745dabce5d0d1ba