Update app.py
Browse files
app.py
CHANGED
@@ -128,6 +128,7 @@ def update_leaderboard_dataset_parallel(hivex_env, path):
|
|
128 |
|
129 |
def process_model(model_id):
|
130 |
meta = get_metadata(model_id)
|
|
|
131 |
if meta is None:
|
132 |
return None
|
133 |
user_id = model_id.split("/")[0]
|
@@ -137,20 +138,19 @@ def update_leaderboard_dataset_parallel(hivex_env, path):
|
|
137 |
results = meta["model-index"][0]["results"][0]
|
138 |
row["Task-ID"] = results["task"]["task-id"]
|
139 |
row["Task"] = results["task"]["name"]
|
140 |
-
|
141 |
if "pattern-id" in results["task"] or "difficulty-id" in results["task"]:
|
142 |
key = "Pattern" if "pattern-id" in results["task"] else "Difficulty"
|
143 |
row[key] = (
|
144 |
pattern_map[results["task"]["pattern-id"]]
|
145 |
if "pattern-id" in results["task"]
|
146 |
-
else
|
147 |
)
|
148 |
-
|
149 |
results_metrics = results["metrics"]
|
150 |
-
|
151 |
for result in results_metrics:
|
152 |
row[result["name"]] = float(result["value"].split("+/-")[0].strip())
|
153 |
-
|
154 |
return row
|
155 |
|
156 |
data = list(thread_map(process_model, model_ids, desc="Processing models"))
|
@@ -267,37 +267,17 @@ def filter_data(rl_env, task_id, selected_values, path):
|
|
267 |
"""
|
268 |
data = get_data(rl_env, task_id, path)
|
269 |
|
270 |
-
#
|
271 |
-
|
272 |
-
|
273 |
-
filter_column = "Pattern"
|
274 |
-
elif "Difficulty" in data.columns:
|
275 |
-
filter_column = "Difficulty"
|
276 |
-
|
277 |
-
# If there are selected values and a filter column, filter the DataFrame
|
278 |
-
if selected_values and filter_column:
|
279 |
data = data[data[filter_column].isin(selected_values)]
|
280 |
|
281 |
return data
|
282 |
|
283 |
-
|
284 |
def update_filtered_data(selected_values, rl_env, task_id, path):
|
285 |
filtered_data = filter_data(rl_env, task_id, selected_values, path)
|
286 |
return filtered_data
|
287 |
|
288 |
-
|
289 |
-
def update_checkbox_group(rl_env, path):
|
290 |
-
dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(rl_env, path)
|
291 |
-
|
292 |
-
# Debugging: Print to verify what's being returned
|
293 |
-
print(f"Updated dp_key: {dp_key}, difficulty_pattern_ids: {difficulty_pattern_ids}")
|
294 |
-
|
295 |
-
if dp_key and difficulty_pattern_ids:
|
296 |
-
return {"choices": [str(dp_id) for dp_id in difficulty_pattern_ids], "label": dp_key}
|
297 |
-
else:
|
298 |
-
return {"choices": [], "label": "No Data Available"}
|
299 |
-
|
300 |
-
|
301 |
run_update_dataset()
|
302 |
|
303 |
block = gr.Blocks(css=custom_css) # Attach the custom CSS here
|
@@ -335,50 +315,49 @@ with block:
|
|
335 |
dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
|
336 |
hivex_env["hivex_env"], path_
|
337 |
)
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
selected_checkboxes = gr.CheckboxGroup(
|
345 |
-
|
346 |
-
)
|
347 |
-
|
348 |
-
for task_id in range(0, hivex_env["task_count"]):
|
349 |
-
task_title = convert_to_title_case(
|
350 |
-
get_task(hivex_env["hivex_env"], task_id, path_)
|
351 |
)
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
row_count = len(data)
|
357 |
-
|
358 |
-
gr_dataframe = gr.DataFrame(
|
359 |
-
value=data,
|
360 |
-
headers=["User", "Model"],
|
361 |
-
datatype=["markdown", "markdown"],
|
362 |
-
row_count=(row_count, "fixed"),
|
363 |
-
)
|
364 |
-
|
365 |
-
# Use gr.State to hold environment and task information
|
366 |
-
rl_env_state = gr.State(value=hivex_env["hivex_env"])
|
367 |
-
task_id_state = gr.State(value=task_id)
|
368 |
-
path_state = gr.State(value=path_)
|
369 |
-
|
370 |
-
# Add a callback to update the DataFrame when checkboxes are changed
|
371 |
-
selected_checkboxes.change(
|
372 |
-
fn=lambda selected_values, rl_env, task_id, path: (
|
373 |
-
update_filtered_data(selected_values, rl_env, task_id, path),
|
374 |
-
update_checkbox_group(rl_env, path)
|
375 |
-
),
|
376 |
-
inputs=[selected_checkboxes, rl_env_state, task_id_state, path_state],
|
377 |
-
outputs=[gr_dataframe, selected_checkboxes],
|
378 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
|
380 |
scheduler = BackgroundScheduler()
|
381 |
scheduler.add_job(restart, "interval", seconds=86400)
|
382 |
scheduler.start()
|
383 |
|
384 |
-
block.launch()
|
|
|
128 |
|
129 |
def process_model(model_id):
|
130 |
meta = get_metadata(model_id)
|
131 |
+
# LOADED_MODEL_METADATA[model_id] = meta if meta is not None else ''
|
132 |
if meta is None:
|
133 |
return None
|
134 |
user_id = model_id.split("/")[0]
|
|
|
138 |
results = meta["model-index"][0]["results"][0]
|
139 |
row["Task-ID"] = results["task"]["task-id"]
|
140 |
row["Task"] = results["task"]["name"]
|
|
|
141 |
if "pattern-id" in results["task"] or "difficulty-id" in results["task"]:
|
142 |
key = "Pattern" if "pattern-id" in results["task"] else "Difficulty"
|
143 |
row[key] = (
|
144 |
pattern_map[results["task"]["pattern-id"]]
|
145 |
if "pattern-id" in results["task"]
|
146 |
+
else results["task"]["difficulty-id"]
|
147 |
)
|
148 |
+
|
149 |
results_metrics = results["metrics"]
|
150 |
+
|
151 |
for result in results_metrics:
|
152 |
row[result["name"]] = float(result["value"].split("+/-")[0].strip())
|
153 |
+
|
154 |
return row
|
155 |
|
156 |
data = list(thread_map(process_model, model_ids, desc="Processing models"))
|
|
|
267 |
"""
|
268 |
data = get_data(rl_env, task_id, path)
|
269 |
|
270 |
+
# If there are selected values, filter the DataFrame
|
271 |
+
if selected_values:
|
272 |
+
filter_column = "Pattern" if "Pattern" in data.columns else "Difficulty"
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
data = data[data[filter_column].isin(selected_values)]
|
274 |
|
275 |
return data
|
276 |
|
|
|
277 |
def update_filtered_data(selected_values, rl_env, task_id, path):
|
278 |
filtered_data = filter_data(rl_env, task_id, selected_values, path)
|
279 |
return filtered_data
|
280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
run_update_dataset()
|
282 |
|
283 |
block = gr.Blocks(css=custom_css) # Attach the custom CSS here
|
|
|
315 |
dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
|
316 |
hivex_env["hivex_env"], path_
|
317 |
)
|
318 |
+
|
319 |
+
print(dp_key)
|
320 |
+
print(difficulty_pattern_ids)
|
321 |
+
|
322 |
+
# Check if dp_key is defined and difficulty_pattern_ids is not empty
|
323 |
+
if dp_key is not None and len(difficulty_pattern_ids) > 0:
|
324 |
selected_checkboxes = gr.CheckboxGroup(
|
325 |
+
[str(dp_id) for dp_id in difficulty_pattern_ids], label=dp_key
|
|
|
|
|
|
|
|
|
|
|
326 |
)
|
327 |
+
|
328 |
+
for task_id in range(0, hivex_env["task_count"]):
|
329 |
+
task_title = convert_to_title_case(
|
330 |
+
get_task(hivex_env["hivex_env"], task_id, path_)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
)
|
332 |
+
with gr.TabItem(f"Task {task_id}: {task_title}"):
|
333 |
+
|
334 |
+
# Display initial data
|
335 |
+
data = get_data(hivex_env["hivex_env"], task_id, path_)
|
336 |
+
row_count = len(data)
|
337 |
+
|
338 |
+
gr_dataframe = gr.DataFrame(
|
339 |
+
value=data,
|
340 |
+
headers=["User", "Model"],
|
341 |
+
datatype=["markdown", "markdown"],
|
342 |
+
row_count=(row_count, "fixed"),
|
343 |
+
)
|
344 |
+
|
345 |
+
# Use gr.State to hold environment and task information
|
346 |
+
rl_env_state = gr.State(value=hivex_env["hivex_env"])
|
347 |
+
task_id_state = gr.State(value=task_id)
|
348 |
+
path_state = gr.State(value=path_)
|
349 |
+
|
350 |
+
# Add a callback to update the DataFrame when checkboxes are changed
|
351 |
+
selected_checkboxes.change(
|
352 |
+
fn=update_filtered_data,
|
353 |
+
inputs=[selected_checkboxes, rl_env_state, task_id_state, path_state],
|
354 |
+
outputs=gr_dataframe,
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
gr.HTML("<p>No difficulty or pattern data available for this environment.</p>")
|
358 |
|
359 |
scheduler = BackgroundScheduler()
|
360 |
scheduler.add_job(restart, "interval", seconds=86400)
|
361 |
scheduler.start()
|
362 |
|
363 |
+
block.launch()
|