PROBE / app.py
mgyigit's picture
Update app.py
01432f8 verified
raw
history blame
16.4 kB
import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
sys.path.append('./src')
sys.path.append('.')
from huggingface_hub import HfApi
repo_id = "HUBioDataLab/PROBE"
api = HfApi()
from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
# ------------------------------------------------------------------
# Helper functions --------------------------------------------------
# ------------------------------------------------------------------
def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate inputs, run evaluation and (optionally) save results.

    The selector arguments arrive as the labels shown in the UI. They are
    translated to internal codes through the ``*_map`` lookup tables, and it is
    those codes (not the raw labels) that are validated, passed to
    ``run_probe`` and passed to ``save_results``.

    Args:
        human_file: Uploaded representation file for the human dataset, or
            ``None`` when nothing was uploaded.
        skempi_file: Uploaded representation file for the SKEMPI dataset, or
            ``None`` when nothing was uploaded.
        model_name_textbox: Method name; used as the representation name.
        benchmark_types: Selected benchmark-type labels.
        similarity_tasks: Selected similarity-task labels.
        function_prediction_aspect: Selected function-prediction aspect label.
        function_prediction_dataset: Forwarded to ``run_probe`` unchanged.
        family_prediction_dataset: Selected family-prediction dataset labels.
        save: When truthy, results are stored through ``save_results``.

    Returns:
        ``0`` when the evaluation completed, ``-1`` when an option could not be
        mapped, a required file is missing, or ``run_probe`` raised.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    # Map the user-facing labels back to the original codes.
    try:
        benchmark_codes = [benchmark_type_map[b] for b in benchmark_types]
        similarity_codes = [similarity_tasks_map[s] for s in similarity_tasks]
        aspect_code = function_prediction_aspect_map[function_prediction_aspect]
        family_codes = [family_prediction_dataset_map[f] for f in family_prediction_dataset]
    except KeyError as e:
        gr.Warning(f"Unrecognized option: {e.args[0]}")
        return -1

    # Validate against the mapped codes: the literals below are internal codes,
    # so comparing them with the raw UI labels would never match.
    needs_human_file = any(
        code in benchmark_codes for code in ('similarity', 'family', 'function')
    )
    if needs_human_file and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if 'affinity' in benchmark_codes and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed…")
    representation_name = model_name_textbox

    try:
        results = run_probe(
            benchmark_codes,
            representation_name,
            human_file,
            skempi_file,
            similarity_codes,
            aspect_code,
            function_prediction_dataset,
            family_codes,
        )
    except Exception:
        # Top-level UI boundary: keep the user-facing warning, but record the
        # traceback so the failure can be diagnosed from the logs.
        logging.getLogger(__name__).exception(
            "run_probe failed for representation %r", representation_name
        )
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    if save:
        save_results(representation_name, benchmark_codes, results)
        gr.Info("Your submission has been processed and results are saved!")
    else:
        gr.Info("Your submission has been processed!")
    return 0
def refresh_data():
    """Restart the space, drop the cached result CSVs and fetch them again.

    Steps, in order:
      1. ask the Hub API to restart the space identified by ``repo_id``;
      2. delete ``/tmp/<name>_results.csv`` for every benchmark type and for
         the combined ``leaderboard`` file, when present;
      3. call ``download_from_hub`` for the benchmark types only (the
         ``leaderboard`` file is not requested).

    Returns:
        None — the function does not hand back the refreshed table.
    """
    api.restart_space(repo_id=repo_id)

    hub_benchmarks = ["similarity", "function", "family", "affinity"]

    # The combined leaderboard cache is cleared too, but never downloaded.
    for cached_name in [*hub_benchmarks, "leaderboard"]:
        cached_csv = f"/tmp/{cached_name}_results.csv"
        if os.path.exists(cached_csv):
            os.remove(cached_csv)

    download_from_hub(hub_benchmarks)
# ------- Leaderboard helpers -----------------------------------------------
def update_metrics(selected_benchmarks):
    """Return the metrics that belong to the selected benchmark types.

    Metrics are looked up in the module-level ``benchmark_metric_mapping`` and
    returned without duplicates, in a stable order: the order of
    ``selected_benchmarks`` first, then the order of each benchmark's metric
    list. Unknown benchmark names contribute nothing.

    Args:
        selected_benchmarks: Iterable of benchmark-type names; ``None`` is
            treated as an empty selection.

    Returns:
        A list of metric names.
    """
    # A dict keeps insertion order, unlike a set whose iteration order depends
    # on string hashing and therefore changes from one process to the next.
    ordered_metrics = {}
    for benchmark in selected_benchmarks or ():
        ordered_metrics.update(dict.fromkeys(benchmark_metric_mapping.get(benchmark, [])))
    return list(ordered_metrics)
def update_leaderboard(selected_methods, selected_metrics):
    """Build the leaderboard table for the current method/metric selection.

    The table comes from ``get_baseline_df``; every ``int``/``float`` cell is
    rounded to four decimals and each entry of the ``Method`` column is passed
    through ``format_method``.

    Returns:
        The resulting DataFrame.
    """

    def _round_number(cell):
        # Non-numeric cells (e.g. method names) pass through untouched.
        if isinstance(cell, (int, float)):
            return round(cell, 4)
        return cell

    table = get_baseline_df(selected_methods, selected_metrics).applymap(_round_number)
    table['Method'] = table['Method'].apply(format_method)
    return table
def format_method(name: str) -> str:
    """Wrap a model name in a coloured, bold <span> so Gradio renders it.

    The colour is looked up in ``color_dict`` (default ``"black"``). The name
    itself is HTML-escaped before being embedded, so a method name containing
    ``<``, ``>``, ``&`` or quotes is shown literally instead of being
    interpreted as markup.

    Args:
        name: Method name as stored in the leaderboard table.

    Returns:
        An HTML ``<span>`` string.
    """
    # Function-scope import keeps this block self-contained.
    import html

    colour = color_dict.get(name, "black")
    # str() keeps the original tolerance for non-string cells (e.g. NaN).
    safe_name = html.escape(str(name))
    return f"<span style='color:{colour}; font-weight:bold;'>{safe_name}</span>"
# ――― text colour for the Method column ―――
def style_method(val: str) -> str:
    """Return the CSS declaration used for one cell of the Method column.

    The text colour is taken from ``color_dict`` (``'black'`` when the value
    has no entry) and the text is always bold.
    """
    text_colour = color_dict.get(val, 'black')
    return f"color:{text_colour}; font-weight:bold;"
# ――― background shading for the top-5 of a numeric column ―――
# darkest → lightest
TOP5_GREENS = ["#006400", "#228B22", "#32CD32", "#7CFC00", "#ADFF2F"]


def highlight_top5(col: pd.Series) -> list[str]:
    """Return one CSS string per cell, shading the five largest values.

    Non-numeric columns get an empty style for every cell. For numeric
    columns the values are ranked in descending order (ties resolved by
    position); ranks 1-5 receive the matching ``TOP5_GREENS`` background,
    everything else — including missing values, whose rank is NaN — gets
    an empty string.

    NOTE(review): the largest values are always treated as best; confirm that
    this holds for every metric column the function is applied to.
    """
    if not pd.api.types.is_numeric_dtype(col):
        return [""] * len(col)

    positions = col.rank(ascending=False, method="first")  # 1 = largest value
    return [
        f"background-color:{TOP5_GREENS[int(pos)-1]};" if pos <= 5 else ""
        for pos in positions
    ]
# ------- Visualisation helpers ---------------------------------------------
def generate_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    """Forward the visualisation selections to ``benchmark_plot``.

    All seven arguments are passed on positionally and unchanged; whatever
    ``benchmark_plot`` returns is handed back to the caller.
    """
    return benchmark_plot(
        benchmark_type,
        methods_selected,
        x_metric,
        y_metric,
        aspect,
        dataset,
        single_metric,
    )
# ---------------------------------------------------------------------------
# Custom CSS for frozen first column and clearer table styles
# ---------------------------------------------------------------------------
# Stylesheet handed to gr.Blocks(css=...). Every rule is scoped to the element
# with id "leaderboard-table", which is the elem_id given to the leaderboard
# gr.Dataframe.
# NOTE(review): the selectors (e.g. ".dataframe-wrap") depend on the markup
# Gradio renders for a Dataframe — confirm they still match the installed version.
CUSTOM_CSS = """
/* freeze first column */
#leaderboard-table table tr th:first-child,
#leaderboard-table table tr td:first-child {
position: sticky;
left: 0;
background: white;
z-index: 2;
}
/* striped rows for readability */
#leaderboard-table table tr:nth-child(odd) {
background: #fafafa;
}
/* centre numeric cells */
#leaderboard-table td:not(:first-child) {
text-align: center;
}
/* scrollable and taller table */
#leaderboard-table .dataframe-wrap {
max-height: 1200px;
overflow-y: auto;
overflow-x: auto;
}
"""
# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------
# Root container of the app; CUSTOM_CSS styles the leaderboard table.
block = gr.Blocks(css=CUSTOM_CSS)
with block:
gr.Markdown(LEADERBOARD_INTRODUCTION)
with gr.Tabs(elem_classes="tab-buttons") as tabs:
# ------------------------------------------------------------------
# 1️⃣ Leaderboard tab
# ------------------------------------------------------------------
# NOTE(review): the same elem_id "probe-benchmark-tab-table" is also given to
# the About and Submit tabs further down, so it is not unique.
with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
# workflow figure at top (height is set to 1000)
gr.Image(
value="./src/data/PROBE_workflow_figure.jpg",
show_label=False,
height=1000,
container=False,
)
gr.Markdown(
"## For detailed explanations of the metrics and benchmarks, please refer to the 📝 About tab.",
elem_classes="leaderboard-note",
)
# First fetch, called with (None, None): used only to discover the method
# names and the metric columns (every column except 'Method').
leaderboard = get_baseline_df(None, None)
method_names = leaderboard['Method'].unique().tolist()
metric_names = leaderboard.columns.tolist(); metric_names.remove('Method')
# Metric columns grouped per benchmark by column-name prefix. This name is
# module-level and is read by update_metrics().
# NOTE(review): the function prefix is 'func' (no trailing underscore) while the
# others are 'sim_', 'fam_' and 'aff_' — confirm this is intentional.
benchmark_metric_mapping = {
"Semantic Similarity Inference": [m for m in metric_names if m.startswith('sim_')],
"Ontology-based Protein Function Prediction": [m for m in metric_names if m.startswith('func')],
"Drug Target Protein Family Classification": [m for m in metric_names if m.startswith('fam_')],
"Protein-Protein Binding Affinity Estimation": [m for m in metric_names if m.startswith('aff_')],
}
# All methods are pre-selected; benchmark types and metrics start unselected.
leaderboard_method_selector = gr.CheckboxGroup(
choices=method_names,
label="Select Methods",
value=method_names,
interactive=True,
)
benchmark_type_selector_lb = gr.CheckboxGroup(
choices=list(benchmark_metric_mapping.keys()),
label="Select Benchmark Types",
value=None,
interactive=True,
)
leaderboard_metric_selector = gr.CheckboxGroup(
choices=metric_names,
label="Select Metrics",
value=None,
interactive=True,
)
# Second fetch, this time with the explicit method and metric lists; int/float
# cells are rounded to 4 decimals for display.
baseline_value = get_baseline_df(method_names, metric_names)
baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
# HTML formatting of the Method column is disabled for the initial table; the
# Styler below colours it instead.
#baseline_value['Method'] = baseline_value['Method'].apply(format_method)
baseline_header = ["Method"] + metric_names
baseline_datatype = ['markdown'] + ['number'] * len(metric_names)
# Initial table styling: per-method text colour plus top-5 shading per column.
styler = (
baseline_value
.style
.applymap(style_method, subset=["Method"]) # text colours
.apply(highlight_top5, axis=0) # green shades
.format(precision=4) # keep tidy numbers
)
with gr.Row(show_progress=True, variant='panel'):
data_component = gr.Dataframe(
value=styler,
headers=baseline_header,
type="pandas",
datatype=baseline_datatype,
interactive=False,
elem_id="leaderboard-table",
pinned_columns=1,
max_height=1000,
show_fullscreen_button=True
)
# callbacks
# NOTE(review): update_leaderboard returns a plain DataFrame whose Method cells
# are wrapped by format_method, whereas the initial value above is a Styler —
# the top-5 shading is therefore not re-applied after a selection change.
leaderboard_method_selector.change(
update_leaderboard,
inputs=[leaderboard_method_selector, leaderboard_metric_selector],
outputs=data_component,
)
# Selecting benchmark types pushes their metrics into the metric selector.
benchmark_type_selector_lb.change(
lambda selected: update_metrics(selected),
inputs=[benchmark_type_selector_lb],
outputs=leaderboard_metric_selector,
)
leaderboard_metric_selector.change(
update_leaderboard,
inputs=[leaderboard_method_selector, leaderboard_metric_selector],
outputs=data_component,
)
# ------------------------------------------------------------------
# 2️⃣ Visualisation tab
# ------------------------------------------------------------------
with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
gr.Markdown(
"""## **Interactive Visualizations**
Choose a benchmark type; context‑specific options will appear.""",
elem_classes="markdown-text",
)
# NOTE(review): benchmark_specific_metrics and update_metric_choices are not
# defined in this file; presumably they come from the star imports at the top.
vis_benchmark_type_selector = gr.Dropdown(
choices=list(benchmark_specific_metrics.keys()),
label="Benchmark Type",
value=None,
)
# The five context-specific dropdowns start empty and hidden; they are the
# outputs of update_metric_choices below.
with gr.Row():
vis_x_metric_selector = gr.Dropdown(choices=[], label="X‑axis Metric", visible=False)
vis_y_metric_selector = gr.Dropdown(choices=[], label="Y‑axis Metric", visible=False)
vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)
# method_names is the list computed in the leaderboard tab; all are pre-selected.
vis_method_selector = gr.CheckboxGroup(
choices=method_names,
label="Methods",
value=method_names,
interactive=True,
)
plot_button = gr.Button("Plot")
with gr.Row(show_progress=True, variant='panel'):
plot_output = gr.Image(label="Plot")
# NOTE(review): plot_explanation is created hidden and is not referenced by any
# callback in this block.
plot_explanation = gr.Markdown(visible=False)
# callbacks
vis_benchmark_type_selector.change(
update_metric_choices,
inputs=[vis_benchmark_type_selector],
outputs=[
vis_x_metric_selector,
vis_y_metric_selector,
vis_aspect_type_selector,
vis_dataset_selector,
vis_single_metric_selector,
],
)
# The input order below matches the parameter order of generate_plot.
plot_button.click(
generate_plot,
inputs=[
vis_benchmark_type_selector,
vis_method_selector,
vis_x_metric_selector,
vis_y_metric_selector,
vis_aspect_type_selector,
vis_dataset_selector,
vis_single_metric_selector,
],
outputs=[plot_output],
)
# ------------------------------------------------------------------
# 3️⃣ About tab
# ------------------------------------------------------------------
# Static content only: the LLM_BENCHMARKS_TEXT markdown followed by the
# workflow figure (the same image file shown on the leaderboard tab).
# NOTE(review): elem_id "probe-benchmark-tab-table" is also used by the
# Leaderboard and Submit tabs.
with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=3):
with gr.Row():
gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
with gr.Row():
gr.Image(
value="./src/data/PROBE_workflow_figure.jpg",
label="PROBE Workflow Figure",
elem_classes="about-image",
)
# ------------------------------------------------------------------
# 4️⃣ Submit tab
# ------------------------------------------------------------------
# NOTE(review): elem_id "probe-benchmark-tab-table" is also used by the
# Leaderboard and About tabs.
with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=4):
with gr.Row():
gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
with gr.Row():
gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
with gr.Row():
with gr.Column():
model_name_textbox = gr.Textbox(label="Method name")
benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Similarity Tasks", interactive=True)
function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Function Prediction Aspects", interactive=True)
family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Family Prediction Datasets", interactive=True)
# Hidden field: always submits the fixed value "All_Data_Sets".
function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)
with gr.Row():
# type='filepath': the callback receives a path for each upload.
human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')
submit_button = gr.Button("Submit Eval")
# NOTE(review): submission_result is never bound as an output; add_new_eval
# reports progress through gr.Info / gr.Warning and its 0 / -1 return value is
# not consumed by this click handler.
submission_result = gr.Markdown()
# The input order below matches the parameter order of add_new_eval
# (function_dataset feeds function_prediction_dataset, save_checkbox feeds save).
submit_button.click(
add_new_eval,
inputs=[
human_file,
skempi_file,
model_name_textbox,
benchmark_types,
similarity_tasks,
function_prediction_aspect,
function_dataset,
family_prediction_dataset,
save_checkbox,
],
)
# global refresh + citation ---------------------------------------------
with gr.Row():
    data_run = gr.Button("Refresh")
    # refresh_data() returns None, so it is registered without outputs:
    # binding it to the leaderboard table would replace the table with an
    # empty value as soon as the button is pressed.
    data_run.click(refresh_data)
# Collapsed-by-default accordion holding the citation text with a copy button.
with gr.Accordion("Citation", open=False):
citation_button = gr.Textbox(
value=CITATION_BUTTON_TEXT,
label=CITATION_BUTTON_LABEL,
elem_id="citation-button",
show_copy_button=True,
)
# ---------------------------------------------------------------------------
# Start the app; this runs on import because it is not guarded by
# `if __name__ == "__main__":`.
block.launch()