hassonofer commited on
Commit
4fbbca3
·
1 Parent(s): 5ca6824

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -77,7 +77,7 @@ def plot_acc_rate(rate_compare_results_df: pl.DataFrame, width: int = 1000, heig
77
 
78
 
79
  def update_data(
80
- dataset: str, benchmark: str, intermediate: bool, mim: bool, log_x: bool
81
  ) -> tuple[alt.LayerChart, pl.DataFrame]:
82
  compare_results_df = pl.read_csv(f"results_{dataset}.csv")
83
  if intermediate is False:
@@ -87,6 +87,9 @@ def update_data(
87
 
88
  x_scale_type = "log" if log_x is True else "linear"
89
 
 
 
 
90
  # Parameter count
91
  if benchmark == "Parameters":
92
  param_compare_results_df = compare_results_df.unique(subset=["Model name"]).sort(
@@ -122,6 +125,7 @@ def update_data(
122
  for col in output_df.columns
123
  ]
124
  )
 
125
  return (chart, output_df.drop("Mistakes", "Samples"))
126
 
127
 
@@ -174,10 +178,20 @@ def app() -> None:
174
  with gr.Column():
175
  pass
176
 
 
 
 
 
 
 
 
 
 
 
177
  plot = gr.Plot(container=False)
178
- table = gr.Dataframe(show_search=True)
179
 
180
- inputs = [dataset_dropdown, benchmark_dropdown, intermediate, mim, log_x]
181
  outputs = [plot, table]
182
  leaderboard.load(update_data, inputs=inputs, outputs=outputs)
183
 
@@ -186,6 +200,7 @@ def app() -> None:
186
  intermediate.change(update_data, inputs=inputs, outputs=outputs)
187
  mim.change(update_data, inputs=inputs, outputs=outputs)
188
  log_x.change(update_data, inputs=inputs, outputs=outputs)
 
189
 
190
  leaderboard.launch()
191
 
 
77
 
78
 
79
  def update_data(
80
+ dataset: str, benchmark: str, intermediate: bool, mim: bool, log_x: bool, search_bar: str
81
  ) -> tuple[alt.LayerChart, pl.DataFrame]:
82
  compare_results_df = pl.read_csv(f"results_{dataset}.csv")
83
  if intermediate is False:
 
87
 
88
  x_scale_type = "log" if log_x is True else "linear"
89
 
90
+ # Filter models
91
+ compare_results_df = compare_results_df.filter(pl.col("Model name").str.contains(search_bar))
92
+
93
  # Parameter count
94
  if benchmark == "Parameters":
95
  param_compare_results_df = compare_results_df.unique(subset=["Model name"]).sort(
 
125
  for col in output_df.columns
126
  ]
127
  )
128
+
129
  return (chart, output_df.drop("Mistakes", "Samples"))
130
 
131
 
 
178
  with gr.Column():
179
  pass
180
 
181
+ with gr.Row():
182
+ with gr.Column():
183
+ pass
184
+
185
+ with gr.Column(scale=2):
186
+ search_bar = gr.Textbox(label="Model Filter", placeholder="e.g. convnext, efficient|mobile")
187
+
188
+ with gr.Column():
189
+ pass
190
+
191
  plot = gr.Plot(container=False)
192
+ table = gr.Dataframe(show_search="search")
193
 
194
+ inputs = [dataset_dropdown, benchmark_dropdown, intermediate, mim, log_x, search_bar]
195
  outputs = [plot, table]
196
  leaderboard.load(update_data, inputs=inputs, outputs=outputs)
197
 
 
200
  intermediate.change(update_data, inputs=inputs, outputs=outputs)
201
  mim.change(update_data, inputs=inputs, outputs=outputs)
202
  log_x.change(update_data, inputs=inputs, outputs=outputs)
203
+ search_bar.change(update_data, inputs=inputs, outputs=outputs)
204
 
205
  leaderboard.launch()
206