# PROBE / app.py
# Last change: "Update app.py" by mgyigit (commit 39e623c, verified), 16.5 kB
__all__ = ['block', 'make_clickable_model', 'make_clickable_user', 'get_submissions']
import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
import zipfile
import tempfile
sys.path.append('./src')
sys.path.append('.')
from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    revision_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate a submission, run the selected PROBE benchmarks on it and store the results.

    Args:
        human_file: Path to the Human-dataset representation CSV, or None.
            Required when "similarity", "family" or "function" is selected.
        skempi_file: Path to the SKEMPI-dataset representation CSV, or None.
            Required when "affinity" is selected.
        model_name_textbox: Method name entered by the user.
        revision_name_textbox: Optional revision name; when non-empty it is used
            as the stored representation name instead of ``model_name_textbox``.
        benchmark_types: Selected benchmark types (None is treated as no selection).
        similarity_tasks, function_prediction_aspect, function_prediction_dataset,
        family_prediction_dataset: Benchmark-specific options forwarded
            unchanged to ``run_probe``.
        save: If truthy, results are saved normally; otherwise they are saved
            with ``temporary=True`` so the leaderboard still includes them.

    Returns:
        0 on success, -1 when validation or the benchmark run fails (a Gradio
        warning is shown to the user in that case).
    """
    import traceback

    # An untouched Gradio CheckboxGroup may deliver None; treat it as "nothing selected".
    selected_types = benchmark_types or []
    if not selected_types:
        gr.Warning("Please select at least one benchmark type!")
        return -1

    # Validate required files based on selected benchmarks
    if human_file is None and any(task in selected_types for task in ('similarity', 'family', 'function')):
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if skempi_file is None and 'affinity' in selected_types:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    # A non-empty revision name takes precedence over the base method name.
    representation_name = revision_name_textbox if revision_name_textbox else model_name_textbox
    if not representation_name:
        gr.Warning("Please provide a method name for your submission!")
        return -1

    gr.Info("Your submission is being processed...")
    try:
        results = run_probe(
            selected_types,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
    except Exception:
        # Top-level UI boundary: keep the user-facing message generic, but leave
        # the real cause in the server logs instead of silently discarding it.
        traceback.print_exc()
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    # Even if save is False, we store the submission (temporarily) so that the leaderboard includes it.
    if save:
        save_results(representation_name, selected_types, results)
    else:
        save_results(representation_name, selected_types, results, temporary=True)
    return 0
def refresh_data():
    """Drop cached result CSVs, re-download them from the Hub and return the fresh leaderboard.

    Removes ``/tmp/<name>_results.csv`` for every benchmark type and for the
    combined leaderboard cache. It then calls ``download_from_hub`` for the
    benchmark types only; the leaderboard cache itself is not downloaded.

    Returns:
        The leaderboard DataFrame from ``get_baseline_df(None, None)``. The
        Refresh button writes this return value into the leaderboard table, so
        returning ``None`` (as before) would blank the table.
    """
    benchmark_types = ["similarity", "function", "family", "affinity"]
    for cache_name in benchmark_types + ["leaderboard"]:
        try:
            os.remove(f"/tmp/{cache_name}_results.csv")
        except FileNotFoundError:
            # Nothing cached for this name yet; that's fine.
            pass
    download_from_hub(benchmark_types)
    return get_baseline_df(None, None)
def download_leaderboard_csv():
    """Write the current leaderboard to a CSV file and return the file's path.

    Each call writes into its own freshly created temporary directory. This
    keeps concurrent users of the app from overwriting, or being served, each
    other's download; the shared fixed path used before allowed that. The
    user-visible file name stays ``leaderboard_download.csv``.
    """
    leaderboard_df = get_baseline_df(None, None)
    out_dir = tempfile.mkdtemp(prefix="probe_leaderboard_")
    csv_path = os.path.join(out_dir, "leaderboard_download.csv")
    leaderboard_df.to_csv(csv_path, index=False)
    return csv_path
def _select_plot_metrics(btype, similarity_tasks, function_prediction_aspect,
                         function_prediction_dataset, family_prediction_dataset):
    """Map one benchmark type plus the submission's options to an (x_metric, y_metric) pair."""

    def first_two(values):
        # First and second entries of a (possibly None/short) selection, None where absent.
        values = list(values or [])
        return (values[0] if values else None,
                values[1] if len(values) > 1 else None)

    if btype == "similarity":
        return first_two(similarity_tasks)
    if btype == "function":
        # NOTE(review): benchmark_plot also takes explicit `aspect`/`dataset`
        # parameters (the UI plot button fills them from separate dropdowns).
        # Here the aspect/dataset go into the x/y slots instead — confirm intended.
        return (function_prediction_aspect or None,
                function_prediction_dataset or None)
    if btype == "family":
        # family_prediction_dataset is assumed to be a list of dataset names.
        return first_two(family_prediction_dataset)
    # "affinity" and anything else: let benchmark_plot use its defaults.
    return None, None


def generate_plots_based_on_submission(benchmark_types, similarity_tasks, function_prediction_aspect, function_prediction_dataset, family_prediction_dataset):
    """Plot every benchmark type chosen at submission time and bundle the images in a ZIP.

    The plot for each type covers all methods currently on the leaderboard. Its
    x/y metrics are derived from the submission's extra selections (see
    ``_select_plot_metrics``). ``benchmark_plot`` may return either a
    matplotlib Figure, which is saved as ``<type>.png`` in a fresh temp
    directory, or a file path, which is used as-is. Types for which
    ``benchmark_plot`` returns None are skipped.

    Returns:
        Path to ``submission_plots.zip`` inside a newly created temp directory.
    """
    tmp_dir = tempfile.mkdtemp()
    # Current leaderboard, used only to obtain the method names to plot.
    leaderboard = get_baseline_df(None, None)
    method_names = leaderboard['Method'].unique().tolist()

    plot_files = []
    for btype in benchmark_types or []:
        x_metric, y_metric = _select_plot_metrics(
            btype,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
        # aspect, dataset and single_metric are passed as None for now.
        plot_img = benchmark_plot(btype, method_names, x_metric, y_metric, None, None, None)
        if isinstance(plot_img, plt.Figure):
            plot_file = os.path.join(tmp_dir, f"{btype}.png")
            plot_img.savefig(plot_file)
            plt.close(plot_img)  # free the figure; the app is long-running
        elif plot_img is not None:
            # benchmark_plot already returned a file path; use it directly.
            plot_file = plot_img
        else:
            # Nothing to add for this type; skipping avoids zipfile.write(None) crashing.
            continue
        plot_files.append(plot_file)

    zip_path = os.path.join(tmp_dir, "submission_plots.zip")
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for plot_file in plot_files:
            zipf.write(plot_file, arcname=os.path.basename(plot_file))
    return zip_path
def submission_callback(
    human_file,
    skempi_file,
    model_name_textbox,
    revision_name_textbox,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save_checkbox,
    return_option,  # Radio selection: "Leaderboard CSV" or "Plot Results"
):
    """Evaluate a submission, then hand back whichever artifact the user asked for.

    The evaluation is delegated to ``add_new_eval``. On success, ``return_option``
    selects the output:

    * "Leaderboard CSV": a CSV of the leaderboard, which now includes the submission.
    * "Plot Results": a ZIP of plots built from the submission's selections.

    Returns:
        A ``(status_message, file_path_or_None)`` pair for the Markdown and File
        output components.
    """
    status = add_new_eval(
        human_file,
        skempi_file,
        model_name_textbox,
        revision_name_textbox,
        benchmark_types,
        similarity_tasks,
        function_prediction_aspect,
        function_prediction_dataset,
        family_prediction_dataset,
        save_checkbox,
    )
    if status == -1:
        return "Submission failed. Please check your files and selections.", None

    if return_option == "Plot Results":
        plots_archive = generate_plots_based_on_submission(
            benchmark_types,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
        return "Your plots are ready for download.", plots_archive

    if return_option == "Leaderboard CSV":
        leaderboard_csv = download_leaderboard_csv()
        return "Your leaderboard CSV (including your submission) is ready for download.", leaderboard_csv

    # Neither option selected (e.g. radio cleared).
    return "Submission processed, but no output option was selected.", None
# --------------------------
# Build the Gradio interface
# --------------------------
# `block` is the top-level Gradio app (it is listed in __all__ at the top of the file).
# The page layout follows the order and nesting of the `with` contexts below, so
# statements in this section must not be reordered.
block = gr.Blocks()
with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # NOTE(review): all three tabs share elem_id "probe-benchmark-tab-table";
        # HTML ids are expected to be unique — confirm whether any CSS relies on this.
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # Leaderboard tab (unchanged from before)
            # Full leaderboard, used here to derive the available methods and metric columns.
            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard['Method'].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            # Copy that still includes 'Method'; it is not used anywhere in the visible code.
            metrics_with_method = metric_names.copy()
            metric_names.remove('Method')
            # Group metric columns by benchmark using their column-name prefix.
            # NOTE(review): "func" has no trailing underscore, unlike the other prefixes — confirm intended.
            benchmark_metric_mapping = {
                "similarity": [metric for metric in metric_names if metric.startswith('sim_')],
                "function": [metric for metric in metric_names if metric.startswith('func')],
                "family": [metric for metric in metric_names if metric.startswith('fam_')],
                "affinity": [metric for metric in metric_names if metric.startswith('aff_')],
            }
            # Methods shown in the table; all are selected initially.
            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods for the Leaderboard",
                value=method_names,
                interactive=True
            )
            # Changing this selection repopulates the metric selector (see .change wiring below).
            benchmark_type_selector = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True
            )
            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics for the Leaderboard",
                value=None,
                interactive=True
            )
            # Initial table contents: every method and metric, numeric cells rounded to 4 decimals.
            # NOTE(review): DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map.
            baseline_value = get_baseline_df(method_names, metric_names)
            baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
            baseline_header = ["Method"] + metric_names
            # The 'Method' column is rendered as markdown, all metric columns as numbers.
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)
            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.components.Dataframe(
                    value=baseline_value,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    visible=True,
                )
            # Re-query the table whenever the method or metric selection changes.
            # NOTE(review): these updates are not rounded like the initial baseline_value.
            leaderboard_method_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component
            )
            # update_metrics (defined outside this file) produces the new metric-selector state.
            benchmark_type_selector.change(
                lambda selected_benchmarks: update_metrics(selected_benchmarks),
                inputs=[benchmark_type_selector],
                outputs=leaderboard_metric_selector
            )
            leaderboard_metric_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component
            )
            with gr.Row():
                gr.Markdown(
                    """
## **Visualize the Leaderboard Results**
Select options to update the visualization.
"""
                )
            # (Plotting section remains available as before; not the focus of the submission callback)
            benchmark_type_selector_plot = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Select Benchmark Type for Plotting",
                value=None
            )
            # All plot-parameter dropdowns start hidden and empty; update_metric_choices
            # (wired below) supplies updates for all five of them.
            with gr.Row():
                x_metric_selector = gr.Dropdown(choices=[], label="Select X-axis Metric", visible=False)
                y_metric_selector = gr.Dropdown(choices=[], label="Select Y-axis Metric", visible=False)
                aspect_type_selector = gr.Dropdown(choices=[], label="Select Aspect Type", visible=False)
                dataset_selector = gr.Dropdown(choices=[], label="Select Dataset", visible=False)
                single_metric_selector = gr.Dropdown(choices=[], label="Select Metric", visible=False)
            method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods to Visualize",
                interactive=True,
                value=method_names
            )
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
            benchmark_type_selector_plot.change(
                update_metric_choices,
                inputs=[benchmark_type_selector_plot],
                outputs=[x_metric_selector, y_metric_selector, aspect_type_selector, dataset_selector, single_metric_selector]
            )
            # benchmark_plot argument order: (benchmark_type, methods, x_metric, y_metric,
            # aspect, dataset, single_metric), matching the inputs list below.
            plot_button.click(
                benchmark_plot,
                inputs=[benchmark_type_selector_plot, method_selector, x_metric_selector, y_metric_selector, aspect_type_selector, dataset_selector, single_metric_selector],
                outputs=plot_output
            )
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=2):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )
        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    revision_name_textbox = gr.Textbox(label="Revision Method Name")
                    benchmark_types = gr.CheckboxGroup(
                        choices=TASK_INFO,
                        label="Benchmark Types",
                        interactive=True,
                    )
                    similarity_tasks = gr.CheckboxGroup(
                        choices=similarity_tasks_options,
                        label="Similarity Tasks (if selected)",
                        interactive=True,
                    )
                    function_prediction_aspect = gr.Radio(
                        choices=function_prediction_aspect_options,
                        label="Function Prediction Aspects (if selected)",
                        interactive=True,
                    )
                    family_prediction_dataset = gr.CheckboxGroup(
                        choices=family_prediction_dataset_options,
                        label="Family Prediction Datasets (if selected)",
                        interactive=True,
                    )
                    # Hidden, fixed value; it is passed to submission_callback as
                    # function_prediction_dataset (see submit_button.click inputs below).
                    function_dataset = gr.Textbox(
                        label="Function Prediction Datasets",
                        visible=False,
                        value="All_Data_Sets"
                    )
                    save_checkbox = gr.Checkbox(
                        label="Save results for leaderboard and visualization",
                        value=True
                    )
            with gr.Row():
                # type='filepath': the callback receives file paths, not file objects.
                human_file = gr.components.File(
                    label="The representation file (csv) for Human dataset",
                    file_count="single",
                    type='filepath'
                )
                skempi_file = gr.components.File(
                    label="The representation file (csv) for SKEMPI dataset",
                    file_count="single",
                    type='filepath'
                )
            # New radio button for output selection.
            return_option = gr.Radio(
                choices=["Leaderboard CSV", "Plot Results"],
                label="Return Output",
                value="Leaderboard CSV",
                interactive=True,
            )
            submit_button = gr.Button("Submit Eval")
            # submission_callback returns (message, file path or None) into these two outputs.
            submission_result_msg = gr.Markdown()
            submission_result_file = gr.File()
            # Input order must match submission_callback's positional parameters.
            submit_button.click(
                submission_callback,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    revision_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                    return_option,
                ],
                outputs=[submission_result_msg, submission_result_file]
            )
    # Refresh clears the cached result CSVs and re-downloads them via refresh_data;
    # whatever refresh_data returns is written into the leaderboard table (data_component).
    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])
    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )
# Starts the Gradio server. There is no `if __name__ == "__main__"` guard,
# so this also runs when the module is imported.
block.launch()